diff --git a/alembic_db/versions/0005_download_manager.py b/alembic_db/versions/0005_download_manager.py new file mode 100644 index 000000000..6efffe1d8 --- /dev/null +++ b/alembic_db/versions/0005_download_manager.py @@ -0,0 +1,115 @@ +""" +Download manager schema. + +Adds the three tables that back the server-side model download manager +: transient job/queue state (``downloads`` + per-segment +``download_segments``) and one-API-key-per-host auth (``host_credentials``). + +Revision ID: 0005_download_manager +Revises: 0004_drop_tag_type +Create Date: 2026-06-27 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0005_download_manager" +down_revision = "0004_drop_tag_type" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "downloads", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("url", sa.Text(), nullable=False), + sa.Column("final_url", sa.Text(), nullable=True), + sa.Column("model_id", sa.String(length=1024), nullable=False), + sa.Column("dest_path", sa.Text(), nullable=False), + sa.Column("temp_path", sa.Text(), nullable=False), + sa.Column("status", sa.String(length=16), nullable=False), + sa.Column("priority", sa.Integer(), nullable=False, server_default="0"), + sa.Column("total_bytes", sa.BigInteger(), nullable=True), + sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("etag", sa.String(length=512), nullable=True), + sa.Column("last_modified", sa.String(length=128), nullable=True), + sa.Column( + "accept_ranges", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column("expected_sha256", sa.String(length=64), nullable=True), + sa.Column("credential_id", sa.String(length=36), nullable=True), + sa.Column( + "allow_any_extension", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"), + sa.Column("error", sa.Text(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"), + sa.CheckConstraint( + "total_bytes IS NULL OR total_bytes >= 0", + name="ck_downloads_total_bytes_nonneg", + ), + ) + op.create_index("ix_downloads_status", "downloads", ["status"]) + op.create_index("ix_downloads_priority", "downloads", ["priority"]) + op.create_index("ix_downloads_model_id", "downloads", ["model_id"]) + + op.create_table( + "download_segments", + sa.Column( + "download_id", + sa.String(length=36), + sa.ForeignKey("downloads.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("idx", sa.Integer(), nullable=False), + sa.Column("start_offset", sa.BigInteger(), nullable=False), + sa.Column("end_offset", sa.BigInteger(), nullable=False), + sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"), + sa.PrimaryKeyConstraint("download_id", "idx", name="pk_download_segments"), + sa.CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"), + sa.CheckConstraint("end_offset >= start_offset", name="ck_segments_range"), + ) + + op.create_table( + "host_credentials", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("host", sa.String(length=255), nullable=False), + sa.Column( + "match_subdomains", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column("label", sa.String(length=255), nullable=True), + sa.Column( + "auth_scheme", sa.String(length=16), nullable=False, server_default="bearer" + ), + sa.Column("header_name", sa.String(length=255), nullable=True), + sa.Column("query_param", sa.String(length=255), nullable=True), + sa.Column("secret", sa.Text(), nullable=False), + sa.Column("secret_last4", sa.String(length=4), nullable=True), + sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.text("true")), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + ) + op.create_index( + "uq_host_credentials_host", "host_credentials", ["host"], unique=True + ) + + +def downgrade() -> None: + op.drop_index("uq_host_credentials_host", table_name="host_credentials") + op.drop_table("host_credentials") + + op.drop_table("download_segments") + + op.drop_index("ix_downloads_model_id", table_name="downloads") + op.drop_index("ix_downloads_priority", table_name="downloads") + op.drop_index("ix_downloads_status", table_name="downloads") + op.drop_table("downloads") diff --git a/app/database/db.py b/app/database/db.py index 0aab09a49..2b09b8147 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -4,7 +4,11 @@ import shutil from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message from filelock import FileLock, Timeout -from comfy.cli_args import args +# NOTE: import the module (not `from ... import args`) so we always read the +# live `args` object. Tests reload `comfy.cli_args`, which replaces the module +# global; a bound `args` reference would go stale and point at the default +# database URL instead of the one configured for the test. +import comfy.cli_args _DB_AVAILABLE = False Session = None @@ -21,6 +25,7 @@ try: from app.database.models import Base import app.assets.database.models # noqa: F401 — register models with Base.metadata + import app.model_downloader.database.models # noqa: F401 — register models with Base.metadata _DB_AVAILABLE = True except ImportError as e: @@ -57,13 +62,13 @@ def get_alembic_config(): config = Config(config_path) config.set_main_option("script_location", scripts_path) - config.set_main_option("sqlalchemy.url", args.database_url) + config.set_main_option("sqlalchemy.url", comfy.cli_args.args.database_url) return config def get_db_path(): - url = args.database_url + url = comfy.cli_args.args.database_url if url.startswith("sqlite:///"): return url.split("///")[1] else: @@ -97,7 +102,7 @@ def _is_memory_db(db_url): def init_db(): - db_url = args.database_url + db_url = comfy.cli_args.args.database_url logging.debug(f"Database URL: {db_url}") if _is_memory_db(db_url): diff --git a/app/model_downloader/api/routes.py b/app/model_downloader/api/routes.py new file mode 100644 index 000000000..0a08edb41 --- /dev/null +++ b/app/model_downloader/api/routes.py @@ -0,0 +1,220 @@ +"""aiohttp routes for the download manager. + +Endpoint surface (all under ``/api/download``), mirroring the response +envelope used by ``app/assets/api/routes.py``: + + POST /api/download/enqueue + GET /api/download + POST /api/download/availability + POST /api/download/clear + POST /api/download/credentials + GET /api/download/credentials + GET /api/download/credentials/{id} + DELETE /api/download/credentials/{id} + GET /api/download/{id} + DELETE /api/download/{id} + POST /api/download/{id}/pause + POST /api/download/{id}/resume + POST /api/download/{id}/cancel + POST /api/download/{id}/priority + +Note on ordering: the static ``credentials`` routes are registered before the +dynamic ``/api/download/{id}`` route so a request to ``.../credentials`` is not +captured as ``id == "credentials"``. +""" + +from __future__ import annotations + +import json + +from aiohttp import web +from pydantic import BaseModel, ValidationError + +from app.model_downloader.api import schemas_in, schemas_out +from app.model_downloader.credentials.store import ( + CREDENTIAL_STORE, + CredentialValidationError, +) +from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError + +ROUTES = web.RouteTableDef() + + +def register_routes(app: web.Application) -> None: + """Wire the download-manager routes into the running aiohttp app.""" + app.add_routes(ROUTES) + + +# ----- envelope helpers (same shape as app/assets/api/routes.py) ----- + + +def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _ok(payload, status: int = 200) -> web.Response: + return web.json_response(payload, status=status) + + +async def _parse(request: web.Request, model: type[BaseModel]): + try: + raw = await request.json() + except json.JSONDecodeError: + return _error(400, "INVALID_JSON", "Request body must be valid JSON.") + try: + return model.model_validate(raw) + except ValidationError as ve: + return _error(400, "INVALID_BODY", "Validation failed.", {"errors": json.loads(ve.json())}) + + +def _from_download_error(e: DownloadError) -> web.Response: + return _error(e.http_status, e.code, e.message) + + +# ----- downloads: collection + enqueue + availability ----- + + +@ROUTES.post("/api/download/enqueue") +async def enqueue(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.EnqueueRequest) + if isinstance(parsed, web.Response): + return parsed + try: + download_id = await DOWNLOAD_MANAGER.enqueue( + parsed.url, + parsed.model_id, + priority=parsed.priority, + expected_sha256=parsed.expected_sha256, + allow_any_extension=parsed.allow_any_extension, + credential_id=parsed.credential_id, + ) + except DownloadError as e: + return _from_download_error(e) + return _ok({"download_id": download_id, "accepted": True}, status=202) + + +@ROUTES.get("/api/download") +async def list_downloads(request: web.Request) -> web.Response: + return _ok({"downloads": await DOWNLOAD_MANAGER.list()}) + + +@ROUTES.post("/api/download/availability") +async def availability(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.AvailabilityRequest) + if isinstance(parsed, web.Response): + return parsed + return _ok({"models": await DOWNLOAD_MANAGER.availability(parsed.models)}) + + +@ROUTES.post("/api/download/clear") +async def clear(request: web.Request) -> web.Response: + deleted = await DOWNLOAD_MANAGER.clear() + return _ok({"deleted": deleted}) + + +# ----- credentials (secrets are write-only) — must precede /{id} ----- + + +@ROUTES.post("/api/download/credentials") +async def upsert_credential(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.CredentialUpsertRequest) + if isinstance(parsed, web.Response): + return parsed + try: + view = await CREDENTIAL_STORE.upsert( + parsed.host, + parsed.secret, + auth_scheme=parsed.auth_scheme, + header_name=parsed.header_name, + query_param=parsed.query_param, + label=parsed.label, + match_subdomains=parsed.match_subdomains, + enabled=parsed.enabled, + ) + except CredentialValidationError as e: + return _error(400, "INVALID_CREDENTIAL", str(e)) + return _ok(schemas_out.credential_to_dict(view), status=201) + + +@ROUTES.get("/api/download/credentials") +async def list_credentials(request: web.Request) -> web.Response: + views = await CREDENTIAL_STORE.list() + return _ok({"credentials": [schemas_out.credential_to_dict(v) for v in views]}) + + +@ROUTES.get("/api/download/credentials/{id}") +async def get_credential(request: web.Request) -> web.Response: + view = await CREDENTIAL_STORE.get(request.match_info["id"]) + if view is None: + return _error(404, "NOT_FOUND", "No such credential.") + return _ok(schemas_out.credential_to_dict(view)) + + +@ROUTES.delete("/api/download/credentials/{id}") +async def delete_credential(request: web.Request) -> web.Response: + deleted = await CREDENTIAL_STORE.delete(request.match_info["id"]) + if not deleted: + return _error(404, "NOT_FOUND", "No such credential.") + return _ok({"deleted": True}) + + +# ----- single download by id (dynamic; registered last) ----- + + +@ROUTES.get("/api/download/{id}") +async def get_download(request: web.Request) -> web.Response: + view = await DOWNLOAD_MANAGER.status(request.match_info["id"]) + if view is None: + return _error(404, "NOT_FOUND", "No such download.") + return _ok(view) + + +@ROUTES.delete("/api/download/{id}") +async def delete_download(request: web.Request) -> web.Response: + try: + await DOWNLOAD_MANAGER.delete(request.match_info["id"]) + except DownloadError as e: + return _from_download_error(e) + return _ok({"deleted": True}) + + +@ROUTES.post("/api/download/{id}/pause") +async def pause(request: web.Request) -> web.Response: + try: + await DOWNLOAD_MANAGER.pause(request.match_info["id"]) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) + + +@ROUTES.post("/api/download/{id}/resume") +async def resume(request: web.Request) -> web.Response: + try: + await DOWNLOAD_MANAGER.resume(request.match_info["id"]) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) + + +@ROUTES.post("/api/download/{id}/cancel") +async def cancel(request: web.Request) -> web.Response: + try: + await DOWNLOAD_MANAGER.cancel(request.match_info["id"]) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) + + +@ROUTES.post("/api/download/{id}/priority") +async def set_priority(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.PriorityRequest) + if isinstance(parsed, web.Response): + return parsed + try: + await DOWNLOAD_MANAGER.set_priority(request.match_info["id"], parsed.priority) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) diff --git a/app/model_downloader/api/schemas_in.py b/app/model_downloader/api/schemas_in.py new file mode 100644 index 000000000..c2db2feb4 --- /dev/null +++ b/app/model_downloader/api/schemas_in.py @@ -0,0 +1,51 @@ +"""Request schemas for the download manager API. + +Pydantic enforces shape at the boundary; handlers operate only on validated +values past that point. +""" + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, Field + +from app.model_downloader.constants import AUTH_SCHEME_BEARER + + +class EnqueueRequest(BaseModel): + url: str + model_id: str + priority: int = 0 + expected_sha256: Optional[str] = None + allow_any_extension: bool = False + credential_id: Optional[str] = None + + +class PriorityRequest(BaseModel): + priority: int + + +class AvailabilityRequest(BaseModel): + """``{model_id: url}`` — the URLs declared in the workflow JSON.""" + + models: dict[str, str] = Field(default_factory=dict) + + +class CredentialUpsertRequest(BaseModel): + host: str + secret: str + auth_scheme: str = AUTH_SCHEME_BEARER + header_name: Optional[str] = None + query_param: Optional[str] = None + label: Optional[str] = None + match_subdomains: bool = False + enabled: bool = True + + +__all__ = [ + "EnqueueRequest", + "PriorityRequest", + "AvailabilityRequest", + "CredentialUpsertRequest", +] diff --git a/app/model_downloader/api/schemas_out.py b/app/model_downloader/api/schemas_out.py new file mode 100644 index 000000000..59ace6430 --- /dev/null +++ b/app/model_downloader/api/schemas_out.py @@ -0,0 +1,26 @@ +"""Response helpers for the download manager API. + +The download/status read models are plain dicts produced by the manager. This +module only needs to mask credentials for output (the secret is never returned). +""" + +from __future__ import annotations + +from app.model_downloader.credentials.store import CredentialView + + +def credential_to_dict(view: CredentialView) -> dict: + """API-safe credential representation — never includes the secret.""" + return { + "id": view.id, + "host": view.host, + "auth_scheme": view.auth_scheme, + "header_name": view.header_name, + "query_param": view.query_param, + "label": view.label, + "match_subdomains": view.match_subdomains, + "enabled": view.enabled, + "secret_last4": view.secret_last4, + "created_at": view.created_at, + "updated_at": view.updated_at, + } diff --git a/app/model_downloader/constants.py b/app/model_downloader/constants.py new file mode 100644 index 000000000..6430b1288 --- /dev/null +++ b/app/model_downloader/constants.py @@ -0,0 +1,47 @@ +"""Shared constants for the download manager. + +Status values are persisted as TEXT in the ``downloads`` table; keep them +stable. The lifecycle is: + + queued -> active -> verifying -> completed + | |-> paused -> (resume) -> active + | |-> failed (network, retryable) -> queued (backoff) + |-> cancelled +""" + +from __future__ import annotations + +# Auth schemes for HostCredential +AUTH_SCHEME_BEARER = "bearer" +AUTH_SCHEME_HEADER = "header" +AUTH_SCHEME_QUERY = "query" +AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY) + +# Hosts for which a bearer token can be sourced from the environment when no +# stored credential matches. Values are the env var names to try, in order. +# Only consulted during auto-resolve for an exact host match over https, so the +# same per-hop boundary rules apply (e.g. the token is dropped on a redirect to +# a CDN host). Kept here so the host->env-var mapping lives in one place. +ENV_TOKEN_HOSTS = { + "huggingface.co": ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"), +} + + +class DownloadStatus: + QUEUED = "queued" + ACTIVE = "active" + PAUSED = "paused" + VERIFYING = "verifying" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + #: States from which a worker is doing (or about to do) network I/O. + LIVE = (QUEUED, ACTIVE, VERIFYING) + #: Terminal states — the job will not transition again on its own. + TERMINAL = (COMPLETED, FAILED, CANCELLED) + + +# Default temp-file suffix. Distinctive so the startup orphan sweep only +# removes files THIS subsystem created, never unrelated *.tmp files. +TMP_SUFFIX = ".comfy-download.part" diff --git a/app/model_downloader/credentials/resolver.py b/app/model_downloader/credentials/resolver.py new file mode 100644 index 000000000..b85fccc89 --- /dev/null +++ b/app/model_downloader/credentials/resolver.py @@ -0,0 +1,111 @@ +"""Turn a stored credential into a per-hop request modifier (PRD section 9.4.2). + +The critical rule: a credential is only ever attached when *the current hop's +host* matches a stored credential, and only over https. This is recomputed +from scratch on every redirect hop, so a token bound to ``huggingface.co`` is +silently dropped when the request is redirected to a presigned CDN host — +which is exactly what these hubs expect. +""" + +from __future__ import annotations + +import asyncio +import os +from dataclasses import dataclass, field +from typing import Optional +from urllib.parse import urlencode, urlsplit, urlunsplit + +from app.model_downloader.constants import ( + AUTH_SCHEME_BEARER, + AUTH_SCHEME_HEADER, + AUTH_SCHEME_QUERY, + ENV_TOKEN_HOSTS, +) +from app.model_downloader.credentials.store import normalize_host +from app.model_downloader.database import queries +from app.model_downloader.database.models import HostCredential + + +@dataclass +class RequestAuth: + """How to modify a single request to carry a credential.""" + + headers: dict[str, str] = field(default_factory=dict) + query: dict[str, str] = field(default_factory=dict) + + def apply_to_url(self, url: str) -> str: + if not self.query: + return url + parts = urlsplit(url) + # Append only the credential params, leaving the original query string + # (including any repeated keys and existing encoding) untouched. + creds = urlencode(self.query) + query = f"{parts.query}&{creds}" if parts.query else creds + return urlunsplit(parts._replace(query=query)) + + +def _matches(cred: HostCredential, hop_host: str) -> bool: + cred_host = cred.host + if hop_host == cred_host: + return True + if cred.match_subdomains: + # Label-boundary suffix: api.example.com matches example.com, but + # evil-example.com does NOT. + return hop_host.endswith("." + cred_host) + return False + + +def _build_auth(cred: HostCredential) -> RequestAuth: + if cred.auth_scheme == AUTH_SCHEME_BEARER: + return RequestAuth(headers={"Authorization": f"Bearer {cred.secret}"}) + if cred.auth_scheme == AUTH_SCHEME_HEADER: + name = cred.header_name or "Authorization" + return RequestAuth(headers={name: cred.secret}) + if cred.auth_scheme == AUTH_SCHEME_QUERY and cred.query_param: + return RequestAuth(query={cred.query_param: cred.secret}) + return RequestAuth() + + +def _resolve_sync( + host: str, scheme: str, explicit_credential_id: Optional[str] +) -> Optional[RequestAuth]: + # Never attach a secret over a non-https hop (PRD section 9.4.2). + if scheme.lower() != "https": + return None + hop_host = normalize_host(host) + if not hop_host: + return None + + if explicit_credential_id is not None: + cred = queries.get_credential(explicit_credential_id) + # An explicit credential is still subject to the per-hop host check — + # it is not forced onto a non-matching host. + if cred is None or not cred.enabled or not _matches(cred, hop_host): + return None + return _build_auth(cred) + + # Auto-resolve: exact host first, then any subdomain-matching credential. + cred = queries.get_credential_by_host(hop_host) + if cred is not None and cred.enabled: + return _build_auth(cred) + for sub in queries.list_subdomain_credentials(): + if sub.enabled and _matches(sub, hop_host): + return _build_auth(sub) + + # Env fallback: only for an exact host match, and only after the DB lookups + # miss, so a user-set credential always takes precedence. The token is never + # persisted; it is read fresh from the environment on each hop. + for var in ENV_TOKEN_HOSTS.get(hop_host, ()): + token = os.environ.get(var) + if token: + return RequestAuth(headers={"Authorization": f"Bearer {token}"}) + return None + + +async def resolve_auth_for_hop( + host: str, scheme: str, *, explicit_credential_id: Optional[str] = None +) -> Optional[RequestAuth]: + """Resolve the credential (if any) to attach for one request hop.""" + return await asyncio.to_thread( + _resolve_sync, host, scheme, explicit_credential_id + ) diff --git a/app/model_downloader/credentials/store.py b/app/model_downloader/credentials/store.py new file mode 100644 index 000000000..ac82d24e3 --- /dev/null +++ b/app/model_downloader/credentials/store.py @@ -0,0 +1,141 @@ +"""The credential store: one API key per host. + +Secrets are write-only over the API — :class:`CredentialView` carries only +masked metadata (``secret_last4`` + scheme + label), never the secret itself. +At-rest protection for v1 is filesystem permissions on the shared DB (the DB +is the trust boundary); encryption-at-rest is a noted future seam. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Optional +from urllib.parse import urlsplit + +from app.model_downloader.constants import ( + AUTH_SCHEME_BEARER, + AUTH_SCHEME_HEADER, + AUTH_SCHEME_QUERY, + AUTH_SCHEMES, +) +from app.model_downloader.database import queries +from app.model_downloader.database.models import HostCredential + + +def normalize_host(host: str) -> str: + """Lowercase, strip port, IDNA-encode.""" + if not host: + return "" + host = host.strip() + if "://" in host: # a full URL was pasted — extract just the host + host = urlsplit(host).hostname or "" + host = host.lower() + if host.startswith("[") and "]" in host: # bracketed IPv6 literal + host = host[1 : host.index("]")] + elif host.count(":") == 1: # host:port (not IPv6) + host = host.split(":", 1)[0] + try: + host = host.encode("idna").decode("ascii") + except (UnicodeError, ValueError): + pass + return host + + +@dataclass(frozen=True) +class CredentialView: + """Masked, API-safe view of a credential — never includes the secret.""" + + id: str + host: str + auth_scheme: str + header_name: Optional[str] + query_param: Optional[str] + label: Optional[str] + match_subdomains: bool + enabled: bool + secret_last4: Optional[str] + created_at: int + updated_at: int + + +def _to_view(row: HostCredential) -> CredentialView: + return CredentialView( + id=row.id, + host=row.host, + auth_scheme=row.auth_scheme, + header_name=row.header_name, + query_param=row.query_param, + label=row.label, + match_subdomains=row.match_subdomains, + enabled=row.enabled, + secret_last4=row.secret_last4, + created_at=row.created_at, + updated_at=row.updated_at, + ) + + +class CredentialValidationError(ValueError): + """A credential upsert had inconsistent fields.""" + + +class CredentialStore: + """Async facade over the ``host_credentials`` table. + + DB access is synchronous (SQLite) and offloaded via ``asyncio.to_thread``. + """ + + async def upsert( + self, + host: str, + secret: str, + *, + auth_scheme: str = AUTH_SCHEME_BEARER, + header_name: Optional[str] = None, + query_param: Optional[str] = None, + label: Optional[str] = None, + match_subdomains: bool = False, + enabled: bool = True, + ) -> CredentialView: + host = normalize_host(host) + if not host: + raise CredentialValidationError("host is required") + if not secret: + raise CredentialValidationError("secret is required") + if auth_scheme not in AUTH_SCHEMES: + raise CredentialValidationError( + f"auth_scheme must be one of {AUTH_SCHEMES}, got {auth_scheme!r}" + ) + if auth_scheme == AUTH_SCHEME_HEADER and not header_name: + header_name = "Authorization" + if auth_scheme == AUTH_SCHEME_QUERY and not query_param: + raise CredentialValidationError( + "query_param is required when auth_scheme='query'" + ) + values = { + "host": host, + "secret": secret, + "secret_last4": secret[-4:] if len(secret) > 4 else None, + "auth_scheme": auth_scheme, + "header_name": header_name, + "query_param": query_param, + "label": label, + "match_subdomains": match_subdomains, + "enabled": enabled, + } + row = await asyncio.to_thread(queries.upsert_credential, values) + return _to_view(row) + + async def list(self) -> list[CredentialView]: + rows = await asyncio.to_thread(queries.list_credentials) + return [_to_view(r) for r in rows] + + async def get(self, credential_id: str) -> Optional[CredentialView]: + row = await asyncio.to_thread(queries.get_credential, credential_id) + return _to_view(row) if row is not None else None + + async def delete(self, credential_id: str) -> bool: + return await asyncio.to_thread(queries.delete_credential, credential_id) + + +CREDENTIAL_STORE = CredentialStore() diff --git a/app/model_downloader/database/models.py b/app/model_downloader/database/models.py new file mode 100644 index 000000000..546c8ba0c --- /dev/null +++ b/app/model_downloader/database/models.py @@ -0,0 +1,173 @@ +"""SQLAlchemy models for the download manager. + +Three tables: + +- ``downloads`` one row per requested file (job + queue state). +- ``download_segments`` per-segment byte progress, for segmented resume. +- ``host_credentials`` one API key per host, reused across downloads. + +On completion a finished file is registered into the assets catalog; +``downloads`` is kept only as job history. +""" + +from __future__ import annotations + +import time +import uuid + +from sqlalchemy import ( + BigInteger, + Boolean, + CheckConstraint, + ForeignKey, + Index, + Integer, + String, + Text, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database.models import Base + + +def _uuid() -> str: + return str(uuid.uuid4()) + + +def _now() -> int: + return int(time.time()) + + +class Download(Base): + __tablename__ = "downloads" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) + # Original requested URL and the final URL after validated redirects. + url: Mapped[str] = mapped_column(Text, nullable=False) + final_url: Mapped[str | None] = mapped_column(Text, nullable=True) + # Canonical "/" identifier (resolved via folder_paths). + model_id: Mapped[str] = mapped_column(String(1024), nullable=False) + # Final on-disk location and the .part write target. + dest_path: Mapped[str] = mapped_column(Text, nullable=False) + temp_path: Mapped[str] = mapped_column(Text, nullable=False) + + status: Mapped[str] = mapped_column(String(16), nullable=False) + priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + total_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + + etag: Mapped[str | None] = mapped_column(String(512), nullable=True) + last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True) + accept_ranges: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + # Optional hub-provided checksum to verify against (NOT the dedup key). + expected_sha256: Mapped[str | None] = mapped_column(String(64), nullable=True) + + # Explicit credential override; otherwise auto-resolved by host. + # RESTRICT keeps a credential from being deleted while a download references it. + credential_id: Mapped[str | None] = mapped_column( + String(36), + ForeignKey("host_credentials.id", ondelete="RESTRICT"), + nullable=True, + ) + allow_any_extension: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) + # How many retryable failures we have seen (for backoff capping). + attempts: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + error: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now) + updated_at: Mapped[int] = mapped_column( + BigInteger, nullable=False, default=_now, onupdate=_now + ) + + segments: Mapped[list[DownloadSegment]] = relationship( + "DownloadSegment", + back_populates="download", + cascade="all,delete-orphan", + passive_deletes=True, + order_by="DownloadSegment.idx", + ) + + credential: Mapped[HostCredential | None] = relationship( + "HostCredential", back_populates="downloads" + ) + + __table_args__ = ( + Index("ix_downloads_status", "status"), + Index("ix_downloads_priority", "priority"), + Index("ix_downloads_model_id", "model_id"), + CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"), + CheckConstraint( + "total_bytes IS NULL OR total_bytes >= 0", + name="ck_downloads_total_bytes_nonneg", + ), + ) + + def __repr__(self) -> str: + return f"" + + +class DownloadSegment(Base): + __tablename__ = "download_segments" + + download_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("downloads.id", ondelete="CASCADE"), + primary_key=True, + ) + idx: Mapped[int] = mapped_column(Integer, primary_key=True) + start_offset: Mapped[int] = mapped_column(BigInteger, nullable=False) + end_offset: Mapped[int] = mapped_column(BigInteger, nullable=False) + bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + + download: Mapped[Download] = relationship("Download", back_populates="segments") + + __table_args__ = ( + CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"), + CheckConstraint("end_offset >= start_offset", name="ck_segments_range"), + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class HostCredential(Base): + __tablename__ = "host_credentials" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) + # Normalized lowercase hostname, e.g. "civitai.com". + host: Mapped[str] = mapped_column(String(255), nullable=False) + match_subdomains: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) + label: Mapped[str | None] = mapped_column(String(255), nullable=True) + auth_scheme: Mapped[str] = mapped_column( + String(16), nullable=False, default="bearer" + ) + header_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + query_param: Mapped[str | None] = mapped_column(String(255), nullable=True) + # The API key itself. Write-only over the API; never returned. See PRD 9.4.4. + secret: Mapped[str] = mapped_column(Text, nullable=False) + secret_last4: Mapped[str | None] = mapped_column(String(4), nullable=True) + enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now) + updated_at: Mapped[int] = mapped_column( + BigInteger, nullable=False, default=_now, onupdate=_now + ) + + downloads: Mapped[list[Download]] = relationship( + "Download", back_populates="credential" + ) + + __table_args__ = ( + Index("uq_host_credentials_host", "host", unique=True), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/model_downloader/database/queries.py b/app/model_downloader/database/queries.py new file mode 100644 index 000000000..c71a234e1 --- /dev/null +++ b/app/model_downloader/database/queries.py @@ -0,0 +1,272 @@ +"""Synchronous DB access for the download manager. + +All functions open their own short-lived session via ``create_session`` and +commit before returning, mirroring ``app/assets`` usage. They are blocking +(SQLite) and should be called from async code through ``asyncio.to_thread``. +""" + +from __future__ import annotations + +import time +from typing import Optional + +from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError + +from app.database.db import create_session +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database.models import ( + Download, + DownloadSegment, + HostCredential, +) + + +# ----- downloads ----- + + +def insert_download(values: dict) -> None: + with create_session() as session: + session.add(Download(**values)) + session.commit() + + +def get_download(download_id: str) -> Optional[Download]: + with create_session() as session: + row = session.get(Download, download_id) + if row is not None: + session.expunge_all() + return row + + +def list_downloads() -> list[Download]: + with create_session() as session: + rows = list( + session.execute( + select(Download).order_by(Download.created_at.desc()) + ).scalars() + ) + session.expunge_all() + return rows + + +def list_segments(download_id: str) -> list[DownloadSegment]: + with create_session() as session: + rows = list( + session.execute( + select(DownloadSegment) + .where(DownloadSegment.download_id == download_id) + .order_by(DownloadSegment.idx) + ).scalars() + ) + session.expunge_all() + return rows + + +def update_download(download_id: str, **fields) -> None: + if not fields: + return + fields.setdefault("updated_at", int(time.time())) + with create_session() as session: + row = session.get(Download, download_id) + if row is None: + return + for key, value in fields.items(): + setattr(row, key, value) + session.commit() + + +def delete_download(download_id: str) -> None: + with create_session() as session: + row = session.get(Download, download_id) + if row is not None: + session.delete(row) + session.commit() + + +def delete_downloads(download_ids: list[str]) -> int: + """Delete many downloads in one transaction; returns the number removed. + + Uses a bulk ``DELETE ... WHERE id IN (...)``. Segment rows are removed by + the ``ON DELETE CASCADE`` foreign key (SQLite ``PRAGMA foreign_keys=ON`` is + set in ``app/database/db.py``), so this stays consistent without loading the + ORM relationship. + """ + if not download_ids: + return 0 + with create_session() as session: + result = session.execute( + delete(Download).where(Download.id.in_(download_ids)) + ) + session.commit() + return result.rowcount or 0 + + +def replace_segments(download_id: str, segments: list[dict]) -> None: + """Atomically replace the segment plan for a download.""" + with create_session() as session: + session.query(DownloadSegment).filter( + DownloadSegment.download_id == download_id + ).delete() + for seg in segments: + session.add(DownloadSegment(download_id=download_id, **seg)) + session.commit() + + +def update_segment_progress(download_id: str, idx: int, bytes_done: int) -> None: + with create_session() as session: + row = session.get(DownloadSegment, {"download_id": download_id, "idx": idx}) + if row is None: + return + row.bytes_done = bytes_done + session.commit() + + +def list_queued_downloads() -> list[Download]: + """Queued rows ordered for admission (priority desc, then FIFO).""" + with create_session() as session: + rows = list( + session.execute( + select(Download) + .where(Download.status == DownloadStatus.QUEUED) + .order_by(Download.priority.desc(), Download.created_at.asc()) + ).scalars() + ) + session.expunge_all() + return rows + + +def reconcile_live_downloads() -> list[Download]: + """Reset any ``active``/``verifying`` rows left by a previous run. + + On a clean restart there can be no live worker, so anything still marked + live is stale. Move it back to ``queued`` (offsets are preserved on the + segment rows) so the scheduler re-admits it. Returns the rows that should + be re-queued by the scheduler (queued + paused). + """ + with create_session() as session: + stale = list( + session.execute( + select(Download).where( + Download.status.in_([DownloadStatus.ACTIVE, DownloadStatus.VERIFYING]) + ) + ).scalars() + ) + now = int(time.time()) + for row in stale: + row.status = DownloadStatus.QUEUED + row.updated_at = now + session.commit() + + resumable = list( + session.execute( + select(Download) + .where(Download.status == DownloadStatus.QUEUED) + .order_by(Download.priority.desc(), Download.created_at.asc()) + ).scalars() + ) + session.expunge_all() + return resumable + + +# ----- host credentials ----- + + +def get_credential(credential_id: str) -> Optional[HostCredential]: + with create_session() as session: + row = session.get(HostCredential, credential_id) + if row is not None: + session.expunge_all() + return row + + +def get_credential_by_host(host: str) -> Optional[HostCredential]: + with create_session() as session: + row = ( + session.execute( + select(HostCredential).where(HostCredential.host == host).limit(1) + ) + .scalars() + .first() + ) + if row is not None: + session.expunge_all() + return row + + +def list_credentials() -> list[HostCredential]: + with create_session() as session: + rows = list( + session.execute( + select(HostCredential).order_by(HostCredential.host) + ).scalars() + ) + session.expunge_all() + return rows + + +def list_subdomain_credentials() -> list[HostCredential]: + """Credentials that opted into subdomain matching, for suffix checks.""" + with create_session() as session: + rows = list( + session.execute( + select(HostCredential).where(HostCredential.match_subdomains.is_(True)) + ).scalars() + ) + session.expunge_all() + return rows + + +def upsert_credential(values: dict) -> HostCredential: + """Insert or update a credential keyed by ``host``. + + Callers can target the same host concurrently (each runs in its own + short-lived session on a separate connection), so the read-then-write here + can race: two callers both see no existing row and both attempt an insert. + The ``host`` column is uniquely indexed, so the loser's insert raises + ``IntegrityError``. We recover by rolling back and retrying, at which point + the now-committed row is found and updated in place, letting concurrent + calls converge instead of failing or creating duplicates. + """ + host = values["host"] + now = int(time.time()) + last_error: IntegrityError | None = None + for _ in range(2): + with create_session() as session: + row = ( + session.execute( + select(HostCredential).where(HostCredential.host == host).limit(1) + ) + .scalars() + .first() + ) + if row is None: + row = HostCredential(**values) + row.created_at = now + row.updated_at = now + session.add(row) + else: + for key, value in values.items(): + setattr(row, key, value) + row.updated_at = now + try: + session.commit() + except IntegrityError as exc: + session.rollback() + last_error = exc + continue + session.refresh(row) + session.expunge(row) + return row + assert last_error is not None + raise last_error + + +def delete_credential(credential_id: str) -> bool: + with create_session() as session: + row = session.get(HostCredential, credential_id) + if row is None: + return False + session.delete(row) + session.commit() + return True diff --git a/app/model_downloader/engine/job.py b/app/model_downloader/engine/job.py new file mode 100644 index 000000000..c0f5e9c2d --- /dev/null +++ b/app/model_downloader/engine/job.py @@ -0,0 +1,612 @@ +"""The per-download worker. + +One :class:`DownloadJob` drives a single file from probe to verified, cataloged +completion. It supports cooperative pause / resume / cancel, segmented +multi-connection transfer with positioned writes, and a verification gate +(size + structural + optional sha256) before the atomic rename into place. + +Control is cooperative: external callers flip ``_control`` via +:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between +chunks and raise, which unwinds cleanly and persists resume offsets. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import time +from dataclasses import dataclass, field +from typing import Callable, Optional + +from comfy.cli_args import args +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.engine.planner import ( + effective_segment_count, + plan_segments, +) +from app.model_downloader.engine.writer import FileWriter +from app.model_downloader.net.http import open_validated, redact_url +from app.model_downloader.net.probe import gated_error_message, probe +from app.model_downloader.verify import checksum, dedup, structural + +_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504} +_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists + + +class Paused(Exception): + pass + + +class Cancelled(Exception): + pass + + +class RemoteChanged(Exception): + """The remote file changed under a resume (got 200 where 206 expected).""" + + +class RetryableError(Exception): + pass + + +class FatalError(Exception): + """Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc.""" + + +@dataclass +class SegmentRuntime: + idx: int + start: int + end: int # inclusive; may be -1 for unknown-size single stream + bytes_done: int = 0 + + @property + def length(self) -> int: + return self.end - self.start + 1 + + +@dataclass +class RuntimeState: + download_id: str + model_id: str + url: str + priority: int + status: str + total_bytes: Optional[int] = None + bytes_done: int = 0 + error: Optional[str] = None + segments: list[SegmentRuntime] = field(default_factory=list) + started_at: float = field(default_factory=time.monotonic) + _last_bytes: int = 0 + _last_time: float = field(default_factory=time.monotonic) + speed_bps: float = 0.0 + + @property + def progress(self) -> Optional[float]: + if not self.total_bytes: + return None + return min(1.0, self.bytes_done / self.total_bytes) + + @property + def eta_seconds(self) -> Optional[float]: + if not self.total_bytes or self.speed_bps <= 0: + return None + remaining = max(0, self.total_bytes - self.bytes_done) + return remaining / self.speed_bps + + +@dataclass +class JobSpec: + download_id: str + url: str + model_id: str + dest_path: str + temp_path: str + priority: int = 0 + credential_id: Optional[str] = None + expected_sha256: Optional[str] = None + allow_any_extension: bool = False + etag: Optional[str] = None + attempts: int = 0 + + +class DownloadJob: + def __init__( + self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None + ) -> None: + self.spec = spec + self._notify = notify_cb + self._control = "run" # run | pause | cancel + self.state = RuntimeState( + download_id=spec.download_id, + model_id=spec.model_id, + url=spec.url, + priority=spec.priority, + status=DownloadStatus.QUEUED, + ) + self._writer: Optional[FileWriter] = None + self._etag: Optional[str] = spec.etag + self._last_persist = 0.0 + + # ----- external control ----- + + def request_pause(self) -> None: + if self._control == "run": + self._control = "pause" + + def request_cancel(self) -> None: + self._control = "cancel" + + def _check_control(self) -> None: + if self._control == "cancel": + raise Cancelled() + if self._control == "pause": + raise Paused() + + # ----- lifecycle ----- + + async def run(self) -> str: + """Run to a terminal/paused state; returns the final status string.""" + await self._set_status(DownloadStatus.ACTIVE, error=None) + try: + pr = await self._probe_and_plan() + await self._transfer(pr) + await self._finalize() + await self._set_status(DownloadStatus.COMPLETED) + except Paused: + await self._persist_progress(force=True) + await self._set_status(DownloadStatus.PAUSED) + except Cancelled: + await self._close_writer() + self._remove_temp() + await self._set_status(DownloadStatus.CANCELLED) + except RemoteChanged: + await self._reset_for_restart() + await self._set_status( + DownloadStatus.QUEUED, error="remote file changed; restarting" + ) + except RetryableError as e: + await self._persist_progress(force=True) + await self._set_status(DownloadStatus.QUEUED, error=str(e)) + except FatalError as e: + await self._close_writer() + self._remove_temp() + await self._set_status(DownloadStatus.FAILED, error=str(e)) + except Exception as e: # unexpected -> treat as retryable + logging.warning( + "[model_downloader] %s unexpected error: %s", + self.spec.model_id, e, exc_info=True, + ) + await self._persist_progress(force=True) + await self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}") + finally: + await self._close_writer() + return self.state.status + + # ----- probe + plan ----- + + async def _probe_and_plan(self): + pr = await probe(self.spec.url, credential_id=self.spec.credential_id) + if not pr.ok: + if pr.gated: + raise FatalError(gated_error_message(self.spec.url, pr)) + if pr.status == 0 or pr.status in _RETRYABLE_STATUSES: + raise RetryableError(pr.error or "probe failed") + raise FatalError(pr.error or f"probe returned HTTP {pr.status}") + + max_bytes = self._max_download_bytes() + if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes: + raise FatalError( + f"file size {pr.total_bytes} exceeds the maximum allowed " + f"download size {max_bytes} (--download-max-bytes)" + ) + + self._etag = pr.etag or self._etag + self.state.total_bytes = pr.total_bytes + await asyncio.to_thread( + queries.update_download, + self.spec.download_id, + final_url=pr.final_url, + total_bytes=pr.total_bytes, + accept_ranges=pr.accept_ranges, + etag=pr.etag, + last_modified=pr.last_modified, + ) + + seg_count = effective_segment_count( + pr.total_bytes, pr.accept_ranges, max(1, args.download_segments) + ) + existing = await asyncio.to_thread(queries.list_segments, self.spec.download_id) + can_resume_segmented = ( + seg_count > 1 + and existing + and pr.total_bytes is not None + and existing[-1].end_offset == pr.total_bytes - 1 + ) + if can_resume_segmented and not self._segmented_part_valid(pr.total_bytes): + # The persisted per-segment offsets describe bytes in a preallocated + # .part that is now gone or the wrong size (e.g. the partial of a + # failed download was swept on restart, or removed by a fatal + # error). Trusting them would skip already-"complete" segments and + # leave zero-filled holes. Discard the offsets and re-plan fresh. + logging.info( + "[model_downloader] %s discarding segmented resume offsets " + "(preallocated .part missing or wrong size); restarting", + self.spec.model_id, + ) + self._remove_temp() + await asyncio.to_thread( + queries.replace_segments, self.spec.download_id, [] + ) + await asyncio.to_thread( + queries.update_download, self.spec.download_id, bytes_done=0 + ) + existing = [] + can_resume_segmented = False + + if can_resume_segmented: + # Resume an existing segmented plan. + self.state.segments = [ + SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done) + for s in existing + ] + elif seg_count > 1 and pr.total_bytes is not None: + plans = plan_segments(pr.total_bytes, seg_count) + await asyncio.to_thread( + queries.replace_segments, + self.spec.download_id, + [ + {"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0} + for p in plans + ], + ) + self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans] + else: + # Single-stream: one logical segment; bytes_done tracked on the row. + row = await asyncio.to_thread(queries.get_download, self.spec.download_id) + resume_from = row.bytes_done if row else 0 + end = (pr.total_bytes - 1) if pr.total_bytes else -1 + # ``row.bytes_done`` may be the SUM of per-segment offsets from a + # prior segmented run (a preallocated, non-contiguous .part). A + # single-stream resume writes a contiguous prefix, so the offset is + # only trustworthy when the on-disk file is exactly that many + # contiguous bytes. This guards the case where a download that ran + # segmented now resolves to one segment (server dropped + # Accept-Ranges, or --download-segments was lowered between runs): + # resuming over non-contiguous data would corrupt the output. + if resume_from > 0 and not self._contiguous_prefix_valid(resume_from): + logging.info( + "[model_downloader] %s discarding untrusted resume offset " + "%d (on-disk .part not a contiguous prefix); restarting", + self.spec.model_id, resume_from, + ) + resume_from = 0 + self._remove_temp() + if await asyncio.to_thread(queries.list_segments, self.spec.download_id): + await asyncio.to_thread( + queries.replace_segments, self.spec.download_id, [] + ) + await asyncio.to_thread( + queries.update_download, self.spec.download_id, bytes_done=0 + ) + self.state.segments = [SegmentRuntime(0, 0, end, resume_from)] + self._recompute_bytes_done() + return pr + + # ----- transfer ----- + + async def _transfer(self, pr) -> None: + self._writer = FileWriter(self.spec.temp_path) + await self._writer.open() + + segmented = len(self.state.segments) > 1 + if segmented and self.state.total_bytes: + await self._writer.preallocate(self.state.total_bytes) + await self._run_segmented() + else: + await self._run_single() + + await self._writer.flush() + + async def _run_segmented(self) -> None: + pending = [ + asyncio.ensure_future(self._run_segment(seg)) + for seg in self.state.segments + if seg.bytes_done < seg.length + ] + if not pending: + return + done, not_done = await asyncio.wait( + pending, return_when=asyncio.FIRST_EXCEPTION + ) + first_exc: Optional[BaseException] = None + for task in done: + exc = task.exception() + if exc is not None and first_exc is None: + first_exc = exc + if first_exc is not None: + for task in not_done: + task.cancel() + await asyncio.gather(*not_done, return_exceptions=True) + raise first_exc + + async def _run_segment(self, seg: SegmentRuntime) -> None: + offset = seg.start + seg.bytes_done + headers = { + "Range": f"bytes={offset}-{seg.end}", + "Accept-Encoding": "identity", + } + if self._etag: + headers["If-Range"] = self._etag + async with open_validated( + "GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers + ) as (resp, _final): + if resp.status == 200: + # Server ignored the range -> remote changed / no resume support. + raise RemoteChanged() + if resp.status not in (206,): + self._raise_for_status(resp.status) + async for chunk in resp.content.iter_chunked(args.download_chunk_size): + self._check_control() + # Never write past this segment's planned range: a + # non-conforming 206 that returns more than the requested + # bytes would otherwise overrun adjacent segments and the + # preallocated file. Cap the write and abort on overflow. + remaining = seg.length - seg.bytes_done + if remaining <= 0: + raise FatalError( + f"segment {seg.idx}: server returned more than the " + f"requested {seg.length} bytes" + ) + overflow = len(chunk) > remaining + if overflow: + chunk = chunk[:remaining] + await self._writer.write_at(offset, chunk) + offset += len(chunk) + seg.bytes_done += len(chunk) + self._recompute_bytes_done() + await self._persist_progress() + if overflow: + raise FatalError( + f"segment {seg.idx}: server returned more than the " + f"requested {seg.length} bytes" + ) + + async def _run_single(self) -> None: + seg = self.state.segments[0] + offset = seg.bytes_done # resume from here for single-stream + headers = {"Accept-Encoding": "identity"} + if offset > 0: + headers["Range"] = f"bytes={offset}-" + if self._etag: + headers["If-Range"] = self._etag + async with open_validated( + "GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers + ) as (resp, _final): + if offset > 0 and resp.status == 200: + # Resume not honoured -> start over from the beginning. Truncate + # the existing partial so stale trailing bytes from the prior + # attempt cannot survive past the new (possibly shorter) end. + offset = 0 + seg.bytes_done = 0 + self.state.bytes_done = 0 + await self._writer.truncate(0) + elif offset > 0 and resp.status != 206: + self._raise_for_status(resp.status) + elif offset == 0 and resp.status != 200: + self._raise_for_status(resp.status) + # Byte ceiling for this stream: the known total when the server + # reported a size, otherwise the configured maximum download size. + # Without a bound, a non-conforming response or an unknown-length + # stream (end == -1) that never closes could fill the disk (DoS). + limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes() + async for chunk in resp.content.iter_chunked(args.download_chunk_size): + self._check_control() + overflow = False + if limit is not None: + remaining = limit - offset + if remaining <= 0: + raise FatalError( + f"download exceeded the maximum size {limit} bytes" + ) + if len(chunk) > remaining: + chunk = chunk[:remaining] + overflow = True + await self._writer.write_at(offset, chunk) + offset += len(chunk) + seg.bytes_done = offset + self.state.bytes_done = offset + await self._persist_progress() + if overflow: + raise FatalError( + f"download exceeded the maximum size {limit} bytes" + ) + + def _max_download_bytes(self) -> Optional[int]: + """Configured maximum download size in bytes, or ``None`` if disabled.""" + cap = getattr(args, "download_max_bytes", 0) + return cap if cap and cap > 0 else None + + def _raise_for_status(self, status: int) -> None: + if status in (401, 403): + raise FatalError( + f"{redact_url(self.spec.url)} returned {status}; add/update an API key for " + f"this host at /api/download/credentials." + ) + if status in _RETRYABLE_STATUSES: + raise RetryableError(f"HTTP {status}") + raise FatalError(f"unexpected HTTP {status}") + + # ----- finalize / verify (PRD section 8.4) ----- + + async def _finalize(self) -> None: + self._check_control() + await self._close_writer() + await self._set_status(DownloadStatus.VERIFYING) + + total = self.state.total_bytes + segmented = len(self.state.segments) > 1 + if segmented: + # The .part was preallocated to total_bytes, so its on-disk size is + # not evidence of completeness: a segment that ends short (truncated + # 206 / server closes mid-range) leaves a zero-filled hole while the + # file size still equals total. Verify each segment wrote its full + # planned range, and trust the byte counter (== sum of segments) + # rather than os.path.getsize for the total check. + for seg in self.state.segments: + if seg.bytes_done != seg.length: + raise FatalError( + f"segment {seg.idx} incomplete: wrote {seg.bytes_done} " + f"of {seg.length} bytes" + ) + observed = self.state.bytes_done + else: + # Single-stream writes a contiguous prefix, so the on-disk size is + # an independent witness of how much actually landed. + observed = os.path.getsize(self.spec.temp_path) + if total is not None and observed != total: + raise FatalError( + f"size mismatch: wrote {observed} of {total} bytes" + ) + + # Structural gate (cheap, no full read) then optional sha256 (full read). + # Both failures are non-retryable (a truncated/corrupt or mismatched file + # will not heal on retry), so surface them as FatalError rather than + # letting the plain Exceptions fall through to the retryable handler. + # ``temp_path`` carries the ``.part`` suffix; pass ``dest_path`` so the + # structural check detects the real file format instead of skipping it. + try: + await asyncio.to_thread( + structural.validate, self.spec.temp_path, self.spec.dest_path + ) + if self.spec.expected_sha256: + await asyncio.to_thread( + checksum.verify_sha256, + self.spec.temp_path, + self.spec.expected_sha256, + ) + except (structural.StructuralError, checksum.ChecksumError) as e: + raise FatalError(str(e)) from e + + os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True) + os.replace(self.spec.temp_path, self.spec.dest_path) + logging.info( + "[model_downloader] completed %s (%d bytes)", + self.spec.model_id, observed, + ) + # Catalog into the assets system (blake3 dedup identity). Best-effort. + await dedup.register_completed(self.spec.dest_path) + + # ----- helpers ----- + + def _recompute_bytes_done(self) -> None: + self.state.bytes_done = sum(s.bytes_done for s in self.state.segments) + now = time.monotonic() + dt = now - self.state._last_time + if dt >= 0.5: + self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt + self.state._last_bytes = self.state.bytes_done + self.state._last_time = now + + async def _persist_progress(self, force: bool = False) -> None: + # Both the DB write and the websocket notify are gated by the same + # throttle: persisting hits SQLite, and notifying broadcasts to every + # client, so doing either per-chunk (small --download-chunk-size or + # many concurrent segments) would overwhelm both. Skip entirely inside + # the window; the next persist (or a forced one) ships the latest bytes. + now = time.monotonic() + if not force and now - self._last_persist < _PERSIST_INTERVAL: + return + self._last_persist = now + # SQLite is blocking; run it off the event loop per the queries module + # contract so progress persists don't stall the web server. + await asyncio.to_thread(self._write_progress) + if self._notify: + self._notify(self.spec.download_id) + + def _write_progress(self) -> None: + queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done) + for seg in self.state.segments: + if seg.end >= seg.start: # skip unknown-size sentinel + queries.update_segment_progress( + self.spec.download_id, seg.idx, seg.bytes_done + ) + + async def _reset_for_restart(self) -> None: + await self._close_writer() + self._remove_temp() + for seg in self.state.segments: + seg.bytes_done = 0 + self.state.bytes_done = 0 + await asyncio.to_thread( + queries.update_download, self.spec.download_id, bytes_done=0 + ) + if await asyncio.to_thread(queries.list_segments, self.spec.download_id): + await asyncio.to_thread( + queries.replace_segments, self.spec.download_id, [] + ) + + async def _close_writer(self) -> None: + if self._writer is not None: + try: + await self._writer.close() + except Exception: + logging.debug("[model_downloader] writer close error", exc_info=True) + self._writer = None + + def _segmented_part_valid(self, total_bytes: int) -> bool: + """True when the temp file is the preallocated segmented ``.part``. + + A segmented transfer preallocates the .part to ``total_bytes`` up front + and tracks how much of each range landed via per-segment offsets. Those + offsets are only trustworthy when the file they describe is still on + disk at its full preallocated size. A missing file (swept after a + failure, removed on a fatal error, deleted by hand) or a wrong-sized one + means the persisted offsets no longer correspond to real bytes and must + not be resumed over. Doing so would skip "complete" segments and leave + zero-filled holes that pass the size-only verification gate. + """ + try: + return os.path.getsize(self.spec.temp_path) == total_bytes + except OSError: + return False + + def _contiguous_prefix_valid(self, prefix_len: int) -> bool: + """True when the temp file is exactly ``prefix_len`` contiguous bytes. + + Single-stream resume appends sequentially, so a valid resume point + implies the .part size equals the persisted offset. A larger file (e.g. + one preallocated to ``total_bytes`` by a previous segmented run) or a + missing/short file means the persisted offset is not a trustworthy + contiguous prefix and must not be resumed over. + """ + try: + return os.path.getsize(self.spec.temp_path) == prefix_len + except OSError: + return False + + def _remove_temp(self) -> None: + try: + os.remove(self.spec.temp_path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning( + "[model_downloader] could not remove %s: %s", self.spec.temp_path, e + ) + + async def _set_status(self, status: str, error: Optional[str] = None) -> None: + # ``error`` is authoritative: passing None clears any prior failure + # text so transitions out of a failure state (retry/success) don't + # leave stale messages on RuntimeState or in the persisted row. + self.state.status = status + self.state.error = error + fields = {"status": status, "bytes_done": self.state.bytes_done, "error": error} + if status == DownloadStatus.QUEUED: + fields["attempts"] = self.spec.attempts + 1 + self.spec.attempts += 1 + await asyncio.to_thread(queries.update_download, self.spec.download_id, **fields) + if self._notify: + self._notify(self.spec.download_id) diff --git a/app/model_downloader/engine/planner.py b/app/model_downloader/engine/planner.py new file mode 100644 index 000000000..175669165 --- /dev/null +++ b/app/model_downloader/engine/planner.py @@ -0,0 +1,51 @@ +"""Segment planning. + +Split a known byte range into S roughly-equal segments, each fetched by its +own coroutine with ``Range: bytes=start-end``. Falls back to a single segment +when the server doesn't support ranges or the size is unknown/too small for +segmentation to be worthwhile. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +# Below this size, the per-connection setup cost outweighs any parallelism. +_MIN_SEGMENT_BYTES = 1 * 1024 * 1024 + + +@dataclass(frozen=True) +class SegmentPlan: + idx: int + start: int + end: int # inclusive + + @property + def length(self) -> int: + return self.end - self.start + 1 + + +def effective_segment_count( + total_bytes: int | None, accept_ranges: bool, configured: int +) -> int: + """How many segments to actually use for this file.""" + if not accept_ranges or total_bytes is None or total_bytes <= 0: + return 1 + by_size = max(1, total_bytes // _MIN_SEGMENT_BYTES) + return max(1, min(configured, by_size)) + + +def plan_segments(total_bytes: int, num_segments: int) -> list[SegmentPlan]: + """Return ``num_segments`` contiguous, inclusive byte ranges covering [0, total).""" + if total_bytes <= 0 or num_segments <= 1: + return [SegmentPlan(idx=0, start=0, end=max(0, total_bytes - 1))] + base = total_bytes // num_segments + plans: list[SegmentPlan] = [] + start = 0 + for i in range(num_segments): + # Last segment soaks up the remainder. + length = base if i < num_segments - 1 else total_bytes - start + end = start + length - 1 + plans.append(SegmentPlan(idx=i, start=start, end=end)) + start = end + 1 + return plans diff --git a/app/model_downloader/engine/writer.py b/app/model_downloader/engine/writer.py new file mode 100644 index 000000000..467d1faac --- /dev/null +++ b/app/model_downloader/engine/writer.py @@ -0,0 +1,110 @@ +"""Positioned, off-loop file writes. + +Network I/O stays on the event loop; every blocking disk op (preallocate, +positioned write, fsync) is run in a bounded thread pool via +``run_in_executor`` so downloads never stall inference or the web server. + +A single file descriptor is opened for the whole download. Segments write to +their own offsets with ``os.pwrite`` — which is offset-addressed and atomic +per call, so concurrent segment writers need no extra locking. Per-chunk +fsync is avoided; we fsync once at completion. + +``os.pwrite`` is unavailable on Windows, so there we fall back to +``os.lseek`` + ``os.write`` guarded by a per-writer lock (the seek/write pair +is not atomic, so concurrent segment writers must be serialized). +""" + +from __future__ import annotations + +import asyncio +import os +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +# One shared, bounded pool for all download disk I/O. +_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer") + +_HAS_PWRITE = hasattr(os, "pwrite") + +# On Windows ``os.open`` defaults to text mode, which translates every ``\n`` +# byte into ``\r\n`` on write and corrupts binary payloads (the file grows by +# one byte per 0x0A). ``O_BINARY`` disables that translation; it does not exist +# on POSIX, where the default is already binary. +_O_BINARY = getattr(os, "O_BINARY", 0) + + +class FileWriter: + """Owns the ``.part`` file descriptor for one download.""" + + def __init__(self, path: str) -> None: + self.path = path + self._fd: Optional[int] = None + # Serializes lseek+write on platforms without os.pwrite (Windows). + self._seek_lock = threading.Lock() + + def _open(self) -> None: + os.makedirs(os.path.dirname(self.path), exist_ok=True) + self._fd = os.open(self.path, os.O_RDWR | os.O_CREAT | _O_BINARY, 0o644) + + async def open(self) -> None: + await asyncio.get_running_loop().run_in_executor(_EXECUTOR, self._open) + + async def preallocate(self, size: int) -> None: + """Grow the file to ``size`` so segments write to their offsets.""" + if self._fd is None or size <= 0: + return + await asyncio.get_running_loop().run_in_executor( + _EXECUTOR, os.ftruncate, self._fd, size + ) + + async def truncate(self, size: int = 0) -> None: + """Truncate the file to ``size`` bytes (default: empty it).""" + if self._fd is None: + return + await asyncio.get_running_loop().run_in_executor( + _EXECUTOR, os.ftruncate, self._fd, size + ) + + def _pwrite_all(self, data: bytes, offset: int) -> None: + """A positioned write may write fewer bytes than requested (signal + interruption, near-ENOSPC); loop until every byte lands so we never + leave a gap while the caller advances by the full chunk length. + + Uses ``os.pwrite`` where available (offset-addressed, atomic per call). + On Windows it falls back to ``os.lseek`` + ``os.write`` under a lock, + since that pair is not atomic across concurrent segment writers.""" + assert self._fd is not None, "writer not opened" + view = memoryview(data) + written = 0 + total = len(view) + while written < total: + if _HAS_PWRITE: + n = os.pwrite(self._fd, view[written:], offset + written) + else: + with self._seek_lock: + os.lseek(self._fd, offset + written, os.SEEK_SET) + n = os.write(self._fd, view[written:]) + if n == 0: + raise OSError( + f"positioned write wrote 0 bytes at offset {offset + written} " + f"({written}/{total} bytes written)" + ) + written += n + + async def write_at(self, offset: int, data: bytes) -> None: + assert self._fd is not None, "writer not opened" + await asyncio.get_running_loop().run_in_executor( + _EXECUTOR, self._pwrite_all, data, offset + ) + + async def flush(self) -> None: + if self._fd is None: + return + await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.fsync, self._fd) + + async def close(self) -> None: + if self._fd is None: + return + fd, self._fd = self._fd, None + await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.close, fd) diff --git a/app/model_downloader/manager.py b/app/model_downloader/manager.py new file mode 100644 index 000000000..60f597237 --- /dev/null +++ b/app/model_downloader/manager.py @@ -0,0 +1,454 @@ +"""Public facade for the download manager. + +This is the only object the server imports. It validates requests, owns the +:class:`Scheduler`, and exposes a small async API plus read models for status. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import uuid +from typing import Callable, Optional + +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.net.probe import gated_error_message, probe +from app.model_downloader.scheduler import SCHEDULER +from app.model_downloader.security import paths +from app.model_downloader.net.http import redact_url +from app.model_downloader.security.allowlist import ( + ALLOWED_MODEL_EXTENSIONS, + filename_extension, + is_host_allowed_url, + is_url_downloadable, + url_path_extension, +) +from app.model_downloader.security.paths import InvalidModelId + +# Non-terminal statuses: an existing row in one of these blocks a re-enqueue. +_LIVE_STATUSES = ( + DownloadStatus.QUEUED, + DownloadStatus.ACTIVE, + DownloadStatus.PAUSED, + DownloadStatus.VERIFYING, +) + + +class DownloadError(Exception): + """A user-facing error with a stable machine-readable code.""" + + def __init__(self, code: str, message: str, status: int = 400) -> None: + super().__init__(message) + self.code = code + self.message = message + self.http_status = status + + +class DownloadManager: + def __init__(self) -> None: + self._scheduler = SCHEDULER + self._notify_cb: Optional[Callable[[str], None]] = None + # Serializes the "check for a live download, then write" critical section + # per model_id. ``downloads`` has no uniqueness constraint on model_id + # (history rows are kept), so without this two concurrent enqueue/resume + # calls could both pass the live check and admit two jobs sharing one + # temp/dest path. The manager is a process singleton over a local SQLite + # DB, so an in-process lock is sufficient (and avoids a migration). + self._model_locks: dict[str, asyncio.Lock] = {} + + def set_notify(self, cb: Optional[Callable[[str], None]]) -> None: + self._notify_cb = cb + self._scheduler.set_notify(cb) + + async def start(self) -> None: + await self._scheduler.start() + + # ----- enqueue ----- + + async def enqueue( + self, + url: str, + model_id: str, + *, + priority: int = 0, + expected_sha256: Optional[str] = None, + allow_any_extension: bool = False, + credential_id: Optional[str] = None, + ) -> str: + # Coarse gate first: host/scheme must be allowlisted, and any extension + # present in the URL path must be a known model type. A URL whose path + # carries NO extension (e.g. Civitai's ``/api/download/models/``) is + # admitted here and its real extension is resolved from the network + # below before the download is finally accepted. + if allow_any_extension: + if not is_host_allowed_url(url): + raise DownloadError( + "URL_NOT_ALLOWED", + "URL is not on the download allowlist (host/scheme).", + ) + elif not is_url_downloadable(url): + raise DownloadError( + "URL_NOT_ALLOWED", + "URL is not on the download allowlist (host/scheme/extension).", + ) + + # When the URL path has no extension, follow it to where it resolves and + # adopt the real extension from the response, forcing the stored + # filename to match. Skipped when the caller opted into any extension. + if not allow_any_extension and url_path_extension(url) == "": + resolved_ext = await self._resolve_extension(url, credential_id) + model_id = paths.apply_extension(model_id, resolved_ext) + + try: + paths.parse_model_id(model_id, allow_any_extension) + dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension) + except InvalidModelId as e: + raise DownloadError("INVALID_MODEL_ID", str(e)) + + if await asyncio.to_thread( + paths.resolve_existing, model_id, allow_any_extension + ): + raise DownloadError( + "ALREADY_AVAILABLE", + f"Model already exists on disk: {model_id}", + status=409, + ) + download_id = str(uuid.uuid4()) + # Hold the per-model lock across the live check and the insert so a + # concurrent enqueue/resume for the same model_id cannot interleave + # between them and create a second job against the same temp/dest path. + async with self._model_lock(model_id): + if await self._has_live_download(model_id): + raise DownloadError( + "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress.", + status=409, + ) + await asyncio.to_thread( + queries.insert_download, + { + "id": download_id, + "url": url, + "model_id": model_id, + "dest_path": dest_path, + "temp_path": temp_path, + "status": DownloadStatus.QUEUED, + "priority": priority, + "expected_sha256": expected_sha256, + "credential_id": credential_id, + "allow_any_extension": allow_any_extension, + }, + ) + logging.info("[model_downloader] enqueued %s -> %s", redact_url(url), model_id) + await self._scheduler.pump() + return download_id + + async def _resolve_extension( + self, url: str, credential_id: Optional[str] + ) -> str: + """Follow ``url`` to its final response and return the real extension. + + Used for allowlisted URLs whose path has no extension (e.g. Civitai + download endpoints): the filename lives in the ``Content-Disposition`` + header or the post-redirect URL. Raises :class:`DownloadError` when the + URL can't be resolved, needs credentials, or resolves to something that + is not a known model file — so we never persist a bogus destination. + """ + pr = await probe(url, credential_id=credential_id) + if not pr.ok: + if pr.gated: + raise DownloadError( + "GATED_REPO" if pr.is_gated_repo else "CREDENTIALS_REQUIRED", + gated_error_message(url, pr), + status=401, + ) + raise DownloadError( + "URL_RESOLVE_FAILED", + f"Could not resolve {redact_url(url)}: {pr.error or 'unknown error'}", + status=502, + ) + ext = filename_extension(pr.filename) if pr.filename else "" + if ext not in ALLOWED_MODEL_EXTENSIONS: + raise DownloadError( + "URL_NOT_ALLOWED", + f"URL resolves to {pr.filename or ''!r}, which is not a " + f"known model file type {ALLOWED_MODEL_EXTENSIONS}.", + ) + return ext + + def _model_lock(self, model_id: str) -> asyncio.Lock: + # Lazily create one lock per model_id. There is no ``await`` between the + # lookup and the insert, so under the single asyncio thread this is + # atomic and cannot hand out two different locks for the same model_id. + lock = self._model_locks.get(model_id) + if lock is None: + lock = asyncio.Lock() + self._model_locks[model_id] = lock + return lock + + async def _has_live_download( + self, model_id: str, *, exclude_id: Optional[str] = None + ) -> bool: + rows = await asyncio.to_thread(queries.list_downloads) + return any( + r.model_id == model_id + and r.id != exclude_id + and r.status in _LIVE_STATUSES + for r in rows + ) + + # ----- control ----- + + async def pause(self, download_id: str) -> None: + job = self._scheduler.get_job(download_id) + if job is not None: + job.request_pause() + return + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + if row.status == DownloadStatus.QUEUED: + await asyncio.to_thread( + queries.update_download, download_id, status=DownloadStatus.PAUSED + ) + + async def resume(self, download_id: str) -> None: + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + if row.status not in (DownloadStatus.PAUSED, DownloadStatus.FAILED): + return + # Re-queueing a paused/failed row must respect the single-live-per-model + # invariant: another download (e.g. a newer enqueue) may already be live + # for this model_id and would share this row's temp/dest path. Hold the + # per-model lock across the check and the status flip, and exclude this + # row itself (a paused row is already a "live" status). + async with self._model_lock(row.model_id): + if await self._has_live_download(row.model_id, exclude_id=download_id): + raise DownloadError( + "ALREADY_DOWNLOADING", + f"A download for {row.model_id} is already in progress.", + status=409, + ) + await asyncio.to_thread( + queries.update_download, + download_id, + status=DownloadStatus.QUEUED, + error=None, + ) + await self._scheduler.pump() + + async def cancel(self, download_id: str) -> None: + job = self._scheduler.get_job(download_id) + if job is not None: + job.request_cancel() + return + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + if row.status in _LIVE_STATUSES: + import os + + try: + os.remove(row.temp_path) + except OSError: + pass + await asyncio.to_thread( + queries.update_download, download_id, status=DownloadStatus.CANCELLED + ) + + async def set_priority(self, download_id: str, priority: int) -> None: + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + await asyncio.to_thread( + queries.update_download, download_id, priority=priority + ) + # Admission-order only; a higher priority is + # picked up the next time a slot frees. Pump in case a slot is free now. + await self._scheduler.pump() + + async def delete(self, download_id: str) -> None: + """Delete a terminal download so it stays gone from history. + + Refuses to delete a live download so a record is never removed out from + under a running worker; cancel it first. Any leftover ``.part`` temp + file (e.g. from a failed transfer) is removed, but the finished model + file on disk is never touched. + """ + if self._scheduler.get_job(download_id) is not None: + raise DownloadError( + "DOWNLOAD_ACTIVE", + "Cannot delete a download that is still in progress.", + status=409, + ) + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + if row.status in _LIVE_STATUSES: + raise DownloadError( + "DOWNLOAD_ACTIVE", + "Cannot delete a download that is still in progress.", + status=409, + ) + + try: + os.remove(row.temp_path) + except OSError: + pass + await asyncio.to_thread(queries.delete_download, download_id) + + async def clear(self) -> int: + """Delete all terminal downloads from history in one transaction. + + Skips anything still live (queued/active/paused/verifying, or a running + job) so an in-flight download is never removed out from under a worker. + Finished model files on disk are never touched; only leftover ``.part`` + temp files from failed/cancelled transfers are removed. Returns the + number of history rows deleted. + """ + + rows = await asyncio.to_thread(queries.list_downloads) + deletable = [ + r + for r in rows + if r.status not in _LIVE_STATUSES + and self._scheduler.get_job(r.id) is None + ] + if not deletable: + return 0 + for r in deletable: + try: + os.remove(r.temp_path) + except OSError: + pass + return await asyncio.to_thread( + queries.delete_downloads, [r.id for r in deletable] + ) + + # ----- read models ----- + + def _view(self, row) -> dict: + """Combine the persisted row with live in-memory progress, if running.""" + job = self._scheduler.get_job(row.id) + bytes_done = row.bytes_done + total = row.total_bytes + speed = None + eta = None + segments = None + if job is not None: + st = job.state + bytes_done = st.bytes_done + total = st.total_bytes if st.total_bytes is not None else total + speed = st.speed_bps + eta = st.eta_seconds + segments = [ + {"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length} + for s in st.segments + if s.end >= s.start + ] + progress = (bytes_done / total) if total else None + return { + "download_id": row.id, + "model_id": row.model_id, + "url": redact_url(row.url), + "status": row.status, + "priority": row.priority, + "total_bytes": total, + "bytes_done": bytes_done, + "progress": progress, + "speed_bps": speed, + "eta_seconds": eta, + "segments": segments, + "error": row.error, + "created_at": row.created_at, + "updated_at": row.updated_at, + } + + def _view_from_state(self, job) -> dict: + """Build a view purely from the live in-memory job state (no DB).""" + st = job.state + return { + "download_id": st.download_id, + "model_id": st.model_id, + "url": redact_url(st.url), + "status": st.status, + "priority": st.priority, + "total_bytes": st.total_bytes, + "bytes_done": st.bytes_done, + "progress": st.progress, + "speed_bps": st.speed_bps, + "eta_seconds": st.eta_seconds, + "segments": [ + {"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length} + for s in st.segments + if s.end >= s.start + ], + "error": st.error, + } + + def status_sync(self, download_id: str) -> Optional[dict]: + """Synchronous status read for the websocket notify path. + + Uses live in-memory state when the job is running (no DB round-trip on + the hot path); falls back to a quick DB read otherwise. + """ + job = self._scheduler.get_job(download_id) + if job is not None: + return self._view_from_state(job) + row = queries.get_download(download_id) + return self._view(row) if row is not None else None + + async def status(self, download_id: str) -> Optional[dict]: + row = await asyncio.to_thread(queries.get_download, download_id) + return self._view(row) if row is not None else None + + async def list(self) -> list[dict]: + rows = await asyncio.to_thread(queries.list_downloads) + return [self._view(r) for r in rows] + + async def availability(self, models: dict[str, str]) -> dict[str, dict]: + """Bulk per-id ``{state, progress, ...}`` for the frontend poll. + + ``state`` is ``available`` (on disk), ``downloading`` (live row), or + ``missing``. Cheap: a path lookup plus an in-memory/DB status check. + """ + rows = await asyncio.to_thread(queries.list_downloads) + by_model: dict[str, object] = {} + for r in rows: + if r.status in _LIVE_STATUSES or r.model_id not in by_model: + by_model[r.model_id] = r + + # ``url_allowed`` mirrors the coarse enqueue gate (host/scheme + a + # non-disallowed extension); URLs whose extension is only known after a + # network resolve — e.g. Civitai download endpoints — report allowed. + out: dict[str, dict] = {} + for model_id, url in models.items(): + try: + exists = await asyncio.to_thread(paths.resolve_existing, model_id) + except InvalidModelId: + out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)} + continue + if exists: + out[model_id] = {"state": "available", "url_allowed": is_url_downloadable(url)} + continue + row = by_model.get(model_id) + if row is not None and row.status in _LIVE_STATUSES: + view = self._view(row) + out[model_id] = { + "state": "downloading", + "url_allowed": is_url_downloadable(url), + "download_id": view["download_id"], + "progress": view["progress"], + "bytes_done": view["bytes_done"], + "total_bytes": view["total_bytes"], + "speed_bps": view["speed_bps"], + } + else: + out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)} + return out + + +DOWNLOAD_MANAGER = DownloadManager() diff --git a/app/model_downloader/net/http.py b/app/model_downloader/net/http.py new file mode 100644 index 000000000..8112b7a20 --- /dev/null +++ b/app/model_downloader/net/http.py @@ -0,0 +1,148 @@ +"""Manual, validated redirect-following request opener. + +Automatic redirects are disabled. We follow hops ourselves +so that on *every* hop we (a) re-validate scheme + reject credentials-in-URL, +(b) recompute which stored credential — if any — applies to that hop's host, +and (c) let the connector's resolver screen the IP. This is the single place +that attaches credentials, so a token can never ride a redirect to a CDN host. +""" + +from __future__ import annotations + +import logging +import re +from contextlib import asynccontextmanager +from typing import AsyncIterator, Optional +from urllib.parse import unquote, urljoin, urlsplit, urlunsplit + +import aiohttp + +from app.model_downloader.credentials.resolver import resolve_auth_for_hop +from app.model_downloader.net.session import get_session +from app.model_downloader.security.ssrf import ( + MAX_REDIRECTS, + SSRFError, + check_redirect_hop, +) + +_REDIRECT_CODES = {301, 302, 303, 307, 308} +DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120) + + +def redact_url(url: str) -> str: + """Drop the query string so a query-scheme secret is never logged/stored.""" + try: + parts = urlsplit(url) + except ValueError: + return "" + return urlunsplit(parts._replace(query="")) + + +_CD_FILENAME_STAR = re.compile( + r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", re.IGNORECASE +) +_CD_FILENAME_QUOTED = re.compile(r'filename\s*=\s*"([^"]+)"', re.IGNORECASE) +_CD_FILENAME_BARE = re.compile(r"filename\s*=\s*([^;]+)", re.IGNORECASE) + + +def filename_from_content_disposition(value: Optional[str]) -> Optional[str]: + """Extract the download filename from a ``Content-Disposition`` header. + + Prefers the RFC 5987 ``filename*=`` form (percent-decoded) over the plain + ``filename=`` form. Any directory components in the value are stripped so a + hostile header can only influence the *name*, never the target directory. + Returns ``None`` when no filename is present. + """ + if not value: + return None + for pat, decode in ( + (_CD_FILENAME_STAR, True), + (_CD_FILENAME_QUOTED, False), + (_CD_FILENAME_BARE, False), + ): + m = pat.search(value) + if not m: + continue + raw = m.group(1).strip().strip('"') + if decode: + try: + raw = unquote(raw) + except Exception: + pass + name = raw.replace("\\", "/").rsplit("/", 1)[-1].strip() + if name: + return name + return None + + +async def _resolve_final_response( + method: str, + url: str, + credential_id: Optional[str], + base_headers: dict[str, str], + timeout: aiohttp.ClientTimeout, +) -> tuple[aiohttp.ClientResponse, str]: + """Follow redirects manually until a non-redirect response. + + Each intermediate redirect response is released before the next hop. + Returns the final ``(response, final_url)``; the caller owns releasing it. + """ + session = await get_session() + current = url + hops = 0 + while True: + check_redirect_hop(current, is_initial_url=(hops == 0)) + parts = urlsplit(current) + auth = await resolve_auth_for_hop( + parts.hostname or "", parts.scheme, explicit_credential_id=credential_id + ) + req_headers = dict(base_headers) + req_url = current + if auth is not None: + req_headers.update(auth.headers) + req_url = auth.apply_to_url(current) + + resp = await session.request( + method, + req_url, + allow_redirects=False, + headers=req_headers, + timeout=timeout, + ) + if resp.status in _REDIRECT_CODES and resp.headers.get("Location"): + next_url = urljoin(str(resp.url), resp.headers["Location"]) + await resp.release() + hops += 1 + if hops > MAX_REDIRECTS: + raise SSRFError( + f"too many redirects (> {MAX_REDIRECTS}) for {redact_url(url)}" + ) + current = next_url + continue + return resp, redact_url(str(resp.url)) + + +@asynccontextmanager +async def open_validated( + method: str, + url: str, + *, + credential_id: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT, +) -> AsyncIterator[tuple[aiohttp.ClientResponse, str]]: + """Open ``method url`` following redirects manually and validated. + + Yields ``(response, final_url)`` where ``final_url`` is redacted of any + query string. The response is released automatically on exit. + """ + resp, final_url = await _resolve_final_response( + method, url, credential_id, dict(headers or {}), timeout + ) + try: + yield resp, final_url + finally: + try: + await resp.release() + except Exception: # pragma: no cover - best-effort cleanup + logging.debug("[model_downloader] response release error", exc_info=True) diff --git a/app/model_downloader/net/probe.py b/app/model_downloader/net/probe.py new file mode 100644 index 000000000..eca0c7fbb --- /dev/null +++ b/app/model_downloader/net/probe.py @@ -0,0 +1,157 @@ +"""Pre-download probe. + +Issues a tiny ranged GET (``Range: bytes=0-0``) — which doubles as a +range-support test — to discover ``Content-Length``, ``Accept-Ranges``, +``ETag``/``Last-Modified``, and the final post-redirect URL. For HuggingFace +LFS files the true size also appears in the non-standard ``X-Linked-Size`` +header, which we read as a fallback. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Optional +from urllib.parse import urlparse, urlsplit + +import aiohttp + +from app.model_downloader.net.http import ( + filename_from_content_disposition, + open_validated, + redact_url, +) +from app.model_downloader.net.session import parse_int_header + +_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30) + + +@dataclass +class ProbeResult: + ok: bool + status: int + final_url: Optional[str] = None + total_bytes: Optional[int] = None + accept_ranges: bool = False + etag: Optional[str] = None + last_modified: Optional[str] = None + gated: bool = False # 401/403 — needs (or has wrong) credentials + error: Optional[str] = None + # HuggingFace's ``X-Error-Code`` header (e.g. ``GatedRepo``, + # ``RepoNotFound``) when the host reports one. Lets us tell "this repo is + # gated — request access" apart from "you just need a token". + error_code: Optional[str] = None + # Filename the server intends this response to be saved as: the + # ``Content-Disposition`` name if present, else the post-redirect URL's + # basename. Used to resolve the real extension for URLs (e.g. Civitai's + # ``/api/download`` endpoints) that carry no extension in their path. + filename: Optional[str] = None + + @property + def is_gated_repo(self) -> bool: + """True when the host says the repo is gated (access must be granted). + + Distinct from a plain missing/invalid token: even a valid credential + won't help until the user accepts the model's terms on its page. + """ + return (self.error_code or "").lower() == "gatedrepo" + + +def _total_from_content_range(value: Optional[str]) -> Optional[int]: + # "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None + if not value or "/" not in value: + return None + total = value.rsplit("/", 1)[1].strip() + return parse_int_header(total) + + +def _filename_from_response( + content_disposition: Optional[str], final_url: Optional[str] +) -> Optional[str]: + name = filename_from_content_disposition(content_disposition) + if name: + return name + if final_url: + base = urlsplit(final_url).path.rsplit("/", 1)[-1] + if base: + return base + return None + + +async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult: + """Probe ``url`` and return discovered metadata, failing soft.""" + try: + async with open_validated( + "GET", + url, + credential_id=credential_id, + headers={"Range": "bytes=0-0", "Accept-Encoding": "identity"}, + timeout=_PROBE_TIMEOUT, + ) as (resp, final_url): + if resp.status in (401, 403): + error_code = resp.headers.get("X-Error-Code") + error_message = resp.headers.get("X-Error-Message") + return ProbeResult( + ok=False, status=resp.status, final_url=final_url, gated=True, + error_code=error_code, + error=( + error_message + or f"host returned {resp.status} (authentication required)" + ), + ) + if resp.status not in (200, 206): + return ProbeResult( + ok=False, status=resp.status, final_url=final_url, + error=f"probe returned HTTP {resp.status}", + ) + + headers = resp.headers + accept_ranges = False + total: Optional[int] = None + if resp.status == 206: + accept_ranges = True + total = _total_from_content_range(headers.get("Content-Range")) + else: # 200: server ignored the range + accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes" + total = parse_int_header(headers.get("Content-Length")) + + if total is None: + total = parse_int_header(headers.get("X-Linked-Size")) + + return ProbeResult( + ok=True, + status=resp.status, + final_url=final_url, + total_bytes=total, + accept_ranges=accept_ranges, + etag=headers.get("ETag"), + last_modified=headers.get("Last-Modified"), + filename=_filename_from_response( + headers.get("Content-Disposition"), final_url + ), + ) + except Exception as e: # network / SSRF / timeout + host = urlparse(url).netloc or "" + logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__) + return ProbeResult(ok=False, status=0, error="probe failed: network error") + + +def gated_error_message(url: str, pr: ProbeResult) -> str: + """Build a user-facing message for a gated/auth-required probe result. + + Distinguishes a *gated* repo (access must be requested/granted on the model + page — a token alone is not enough) from a plain missing/invalid credential. + """ + redacted = redact_url(url) + if pr.is_gated_repo: + detail = (pr.error or "access is restricted").rstrip() + if detail and not detail.endswith((".", "!", "?")): + detail += "." + return ( + f"{redacted} is a gated model — {detail} Request access on the model's " + f"page, add an API key for this host at /api/download/credentials, and retry." + ) + return ( + f"{redacted} requires authentication. Add an API key for this host at " + f"/api/download/credentials and retry." + ) diff --git a/app/model_downloader/net/session.py b/app/model_downloader/net/session.py new file mode 100644 index 000000000..8270e051c --- /dev/null +++ b/app/model_downloader/net/session.py @@ -0,0 +1,72 @@ +"""Lazily-created shared :class:`aiohttp.ClientSession`. + +A single session reuses TLS handshakes and TCP connections across the probe +and the many segment GETs to the same host (HuggingFace is the dominant +case), which is a large speedup on cold connections and exactly the +connection-reuse strategy that lets us match aria2c. + +The connector uses :class:`ValidatingResolver` so every connection — initial +or post-redirect — is screened for private/special-use IPs at connect time. +TLS is pinned to certifi's CA bundle because the OS trust store is not wired +up on some Python installs (python.org macOS, slim containers). +""" + +from __future__ import annotations + +import asyncio +import ssl +from typing import Optional + +import aiohttp + +try: + import certifi + _CA_FILE = certifi.where() +except Exception: # pragma: no cover - certifi is a transitive dep of aiohttp + _CA_FILE = None + +from comfy.cli_args import args +from app.model_downloader.security.ssrf import ValidatingResolver + +_session: Optional[aiohttp.ClientSession] = None +_lock = asyncio.Lock() + + +def ssl_context() -> ssl.SSLContext: + if _CA_FILE is not None: + return ssl.create_default_context(cafile=_CA_FILE) + return ssl.create_default_context() + + +async def get_session() -> aiohttp.ClientSession: + """Return the shared session, creating it on first use.""" + global _session + if _session is not None and not _session.closed: + return _session + async with _lock: + if _session is None or _session.closed: + connector = aiohttp.TCPConnector( + limit_per_host=max(1, getattr(args, "download_max_connections_per_host", 16)), + ssl=ssl_context(), + resolver=ValidatingResolver(), + ) + _session = aiohttp.ClientSession(connector=connector) + return _session + + +async def close_session() -> None: + global _session + if _session is not None and not _session.closed: + await _session.close() + _session = None + + +def parse_int_header(value: Optional[str]) -> Optional[int]: + """Parse a non-negative integer header value, or None if bad/absent.""" + if not value: + return None + try: + n = int(value) + except (TypeError, ValueError): + return None + return n if n >= 0 else None diff --git a/app/model_downloader/scheduler.py b/app/model_downloader/scheduler.py new file mode 100644 index 000000000..b41d2cb3b --- /dev/null +++ b/app/model_downloader/scheduler.py @@ -0,0 +1,177 @@ +"""Priority scheduler + lifecycle. + +Owns the set of running jobs and admits queued downloads up to a global +concurrency limit (K), highest priority first, FIFO within a priority. Runs +entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing, +DB) is offloaded by the job/writer layers. + +On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a +previous run are reset to ``queued`` and resumed from persisted offsets, and +orphaned ``.part`` files with no live download row are swept. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import random +import time +from typing import Callable, Optional + +from comfy.cli_args import args +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.engine.job import DownloadJob, JobSpec +from app.model_downloader.security import paths + +# Backoff for retryable failures +_BACKOFF_BASE = 2.0 +_BACKOFF_CAP = 300.0 +_MAX_ATTEMPTS = 6 + + +class Scheduler: + def __init__(self) -> None: + self._jobs: dict[str, DownloadJob] = {} + self._tasks: dict[str, asyncio.Task] = {} + self._backoff_until: dict[str, float] = {} + self._pump_lock = asyncio.Lock() + self._notify_cb: Optional[Callable[[str], None]] = None + self._started = False + + @property + def max_active(self) -> int: + return max(1, getattr(args, "download_max_active", 3)) + + def set_notify(self, cb: Optional[Callable[[str], None]]) -> None: + self._notify_cb = cb + + def get_job(self, download_id: str) -> Optional[DownloadJob]: + return self._jobs.get(download_id) + + def is_active(self, download_id: str) -> bool: + return download_id in self._tasks + + # ----- startup ----- + + async def start(self) -> None: + if self._started: + return + self._started = True + try: + await asyncio.to_thread(queries.reconcile_live_downloads) + await asyncio.to_thread(self._sweep_orphan_temp_files) + except Exception as e: + logging.warning("[model_downloader] startup reconcile failed: %s", e) + await self.pump() + + @staticmethod + def _sweep_orphan_temp_files() -> None: + """Remove ``.part`` files not referenced by a resumable download row. + + Resumable partials are preserved; only truly orphaned temp files from + crashed runs are deleted. ``FAILED`` is included because + :meth:`DownloadManager.resume` explicitly permits resuming a + retry-exhausted failed row: deleting its partial here while the + per-segment offsets survive in the DB would make the next resume + preallocate a fresh sparse file, skip every "complete" segment, and + leave zero-filled holes that pass the size-only verification gate. + """ + live = { + row.temp_path + for row in queries.list_downloads() + if row.status + in ( + DownloadStatus.QUEUED, + DownloadStatus.PAUSED, + DownloadStatus.FAILED, + ) + } + for path in paths.iter_all_tmp_paths(): + if path in live: + continue + try: + os.remove(path) + logging.info("[model_downloader] removed orphan temp file: %s", path) + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + # ----- admission ----- + + async def pump(self) -> None: + async with self._pump_lock: + slots = self.max_active - len(self._tasks) + if slots <= 0: + return + now = time.monotonic() + candidates = await asyncio.to_thread(queries.list_queued_downloads) + for row in candidates: + if slots <= 0: + break + if row.id in self._tasks: + continue + if self._backoff_until.get(row.id, 0.0) > now: + continue + self._admit(row) + slots -= 1 + + def _admit(self, row) -> None: + spec = JobSpec( + download_id=row.id, + url=row.url, + model_id=row.model_id, + dest_path=row.dest_path, + temp_path=row.temp_path, + priority=row.priority, + credential_id=row.credential_id, + expected_sha256=row.expected_sha256, + allow_any_extension=row.allow_any_extension, + etag=row.etag, + attempts=row.attempts, + ) + job = DownloadJob(spec, notify_cb=self._notify_cb) + self._jobs[row.id] = job + self._tasks[row.id] = asyncio.ensure_future(self._run_job(job)) + + async def _run_job(self, job: DownloadJob) -> None: + download_id = job.spec.download_id + status = DownloadStatus.FAILED + try: + status = await job.run() + except Exception as e: # run() is defensive, but never let a task die silently + logging.error("[model_downloader] job %s crashed: %s", download_id, e) + queries.update_download( + download_id, + status=DownloadStatus.FAILED, + error=f"internal error: {e}", + ) + if self._notify_cb: + self._notify_cb(download_id) + finally: + self._tasks.pop(download_id, None) + self._jobs.pop(download_id, None) + + if status == DownloadStatus.QUEUED: + if job.spec.attempts >= _MAX_ATTEMPTS: + queries.update_download( + download_id, + status=DownloadStatus.FAILED, + error=f"giving up after {job.spec.attempts} attempts", + ) + if self._notify_cb: + self._notify_cb(download_id) + else: + delay = min( + _BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts + ) + random.uniform(0, 1.0) + self._backoff_until[download_id] = time.monotonic() + delay + asyncio.ensure_future(self._delayed_pump(delay)) + await self.pump() + + async def _delayed_pump(self, delay: float) -> None: + await asyncio.sleep(delay) + await self.pump() + + +SCHEDULER = Scheduler() diff --git a/app/model_downloader/security/allowlist.py b/app/model_downloader/security/allowlist.py new file mode 100644 index 000000000..f1e0ecbc0 --- /dev/null +++ b/app/model_downloader/security/allowlist.py @@ -0,0 +1,140 @@ +"""URL allowlist for server-side model fetches. + +Default-deny. A URL is downloadable only when its parsed host + scheme are +allowlisted AND (unless explicitly relaxed) its final filename ends in a +known model extension. + +The built-in host defaults mirror the frontend's ``isModelDownloadable`` +allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts`` +extends it for self-hosted mirrors. Matching is done on ``urlparse().hostname`` +(never a raw string prefix) so userinfo tricks like +``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the +metadata IP — cannot slip past. +""" + +from __future__ import annotations + +from urllib.parse import urlparse + +from comfy.cli_args import args + +# host -> set of allowed schemes. Frontend parity (HuggingFace / Civitai / +# localhost). Extra hosts from --download-allowed-hosts are https-only. +_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = { + "huggingface.co": {"https"}, + "civitai.com": {"https"}, + "localhost": {"http", "https"}, + "127.0.0.1": {"http", "https"}, +} + +# Hosts for which loopback addresses are intentionally permitted (the localhost +# "download a local model" feature). Every other host's loopback resolution is +# rejected by the SSRF resolver. +LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"}) + +# Known model file extensions (frontend parity). Checked on the final filename. +ALLOWED_MODEL_EXTENSIONS = ( + ".safetensors", + ".sft", + ".ckpt", + ".pth", + ".pt", + ".gguf", + ".bin", +) + + +def _allowed_hosts() -> dict[str, set[str]]: + hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()} + for extra in getattr(args, "download_allowed_hosts", []) or []: + host = extra.strip().lower() + if host: + hosts.setdefault(host, set()).add("https") + return hosts + + +def is_host_allowed(host: str | None, scheme: str | None) -> bool: + """True iff ``host`` is allowlisted for ``scheme``. + + Used both for the initial URL and re-checked on every redirect hop, + so a whitelisted URL cannot 30x into an off-list host. + """ + if not host or not scheme: + return False + allowed = _allowed_hosts().get(host.lower()) + return allowed is not None and scheme.lower() in allowed + + +def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool: + if allow_any_extension: + return True + return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS) + + +def filename_extension(name: str) -> str: + """Lowercased extension (including the leading dot) of a bare filename. + + Returns ``""`` when there is no extension. A leading-dot name + (``.safetensors``) is treated as having no extension (all stem), matching + ``os.path.splitext`` semantics so dotfiles aren't mistaken for typed files. + """ + base = name.replace("\\", "/").rsplit("/", 1)[-1] + dot = base.rfind(".") + if dot <= 0: + return "" + return base[dot:].lower() + + +def is_allowed_extension_name(name: str) -> bool: + """True iff ``name`` ends in one of the known model extensions.""" + return name.lower().endswith(ALLOWED_MODEL_EXTENSIONS) + + +def is_host_allowed_url(url: str) -> bool: + """True iff ``url`` parses and its host+scheme are allowlisted.""" + if not isinstance(url, str) or not url: + return False + try: + parsed = urlparse(url) + except ValueError: + return False + return is_host_allowed(parsed.hostname, parsed.scheme) + + +def url_path_extension(url: str) -> str: + """Extension of the URL *path* basename (query ignored), or ``""``.""" + try: + parsed = urlparse(url) + except ValueError: + return "" + return filename_extension(parsed.path) + + +def is_url_downloadable(url: str) -> bool: + """Coarse enqueue gate: host/scheme allowed and extension not disallowed. + + Unlike :func:`is_url_allowed` (which demands a known extension *in the URL*), + this also admits URLs whose path carries no extension at all — e.g. a Civitai + ``/api/download/models/`` endpoint whose real filename only shows up in + the redirect target / ``Content-Disposition``. The true extension is then + resolved from the network and re-validated before the download is admitted. + A path bearing an explicit *non-model* extension (``.zip``, ``.html``, ...) + is still rejected here. + """ + if not is_host_allowed_url(url): + return False + ext = url_path_extension(url) + return ext == "" or ext in ALLOWED_MODEL_EXTENSIONS + + +def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool: + """Check whether ``url`` is permitted as a server-side download source.""" + if not isinstance(url, str) or not url: + return False + try: + parsed = urlparse(url) + except ValueError: + return False + if not is_host_allowed(parsed.hostname, parsed.scheme): + return False + return has_allowed_extension(parsed.path, allow_any_extension) diff --git a/app/model_downloader/security/paths.py b/app/model_downloader/security/paths.py new file mode 100644 index 000000000..6b483a42a --- /dev/null +++ b/app/model_downloader/security/paths.py @@ -0,0 +1,132 @@ +"""Path resolution + traversal safety for downloads. + +A ``model_id`` is a *relative destination path* of the form +``/`` (e.g. ``loras/my_lora.safetensors``). This module +turns one into an absolute on-disk path under one of ComfyUI's registered +model folders, rejecting unknown folders, path traversal, and symlink escape. +This is the only thing that composes destination paths, so the engine never +touches user-supplied path strings directly. +""" + +from __future__ import annotations + +import os +import re +from typing import Iterator, Optional + +import folder_paths + +from app.model_downloader.constants import TMP_SUFFIX +from app.model_downloader.security.allowlist import ALLOWED_MODEL_EXTENSIONS + +# A model_id component is a single path segment of safe characters — no slashes, +# no "..", no leading dots that could escape the target directory. +_SEGMENT_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") + + +class InvalidModelId(ValueError): + """Raised when a model_id is malformed or names an unknown model folder.""" + + +def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[str, str]: + """Split ``/`` and validate both components. + + Returns ``(directory, filename)``. Does not touch the filesystem. + """ + if not isinstance(model_id, str) or "/" not in model_id: + raise InvalidModelId( + f"model_id must be '/', got {model_id!r}" + ) + directory, _, filename = model_id.partition("/") + if "/" in filename or not directory or not filename: + raise InvalidModelId( + f"model_id must have exactly one '/' separator, got {model_id!r}" + ) + if not _SEGMENT_RE.match(directory): + raise InvalidModelId(f"invalid directory segment {directory!r}") + if not _SEGMENT_RE.match(filename): + raise InvalidModelId(f"invalid filename segment {filename!r}") + if not allow_any_extension and not filename.lower().endswith( + ALLOWED_MODEL_EXTENSIONS + ): + raise InvalidModelId( + f"filename must end with a known model extension " + f"{ALLOWED_MODEL_EXTENSIONS}, got {filename!r}" + ) + if directory not in folder_paths.folder_names_and_paths: + raise InvalidModelId(f"unknown model folder {directory!r}") + return directory, filename + + +def apply_extension(model_id: str, ext: str) -> str: + """Return ``model_id`` with its filename forced to end in ``ext``. + + ``ext`` includes the leading dot (e.g. ``".safetensors"``). If the filename + already ends in a *known model extension* it is replaced; otherwise ``ext`` + is appended (so ``loras/mymodel`` -> ``loras/mymodel.safetensors`` and + ``loras/mymodel.ckpt`` -> ``loras/mymodel.safetensors``). A filename with a + non-model suffix (``my.model.v2``) is treated as an extensionless stem and + ``ext`` is appended. The directory part is left untouched; validation is + still the caller's job via :func:`parse_model_id`. + """ + directory, sep, filename = model_id.partition("/") + if not sep: + return model_id # malformed; parse_model_id will reject it + low = filename.lower() + for known in ALLOWED_MODEL_EXTENSIONS: + if low.endswith(known): + filename = filename[: -len(known)] + break + return f"{directory}{sep}{filename}{ext}" + + +def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]: + """Return the absolute path of an installed model, or None if missing. + + Honours ``extra_model_paths.yaml`` transparently via ``get_full_path``. + """ + directory, filename = parse_model_id(model_id, allow_any_extension) + return folder_paths.get_full_path(directory, filename) + + +def resolve_destination( + model_id: str, allow_any_extension: bool = False +) -> tuple[str, str]: + """Return ``(final_path, temp_path)`` for a download. + + Downloads land at the first registered path for the model's directory + (the "primary" location). ``temp_path`` is a sibling ``.part`` file that + is atomically renamed onto ``final_path`` on success. The result is + asserted to stay within the registered root (defence in depth on top of + the segment regex). + """ + directory, filename = parse_model_id(model_id, allow_any_extension) + roots = folder_paths.get_folder_paths(directory) + if not roots: + raise InvalidModelId(f"no on-disk path registered for folder {directory!r}") + root = os.path.realpath(roots[0]) + final_path = os.path.realpath(os.path.join(root, filename)) + if final_path != root and not final_path.startswith(root + os.sep): + raise InvalidModelId(f"resolved path escapes model root: {model_id!r}") + temp_path = f"{final_path}{TMP_SUFFIX}" + return final_path, temp_path + + +def iter_all_tmp_paths() -> Iterator[str]: + """Yield this subsystem's temp files under every registered model folder. + + Matches only the distinctive ``TMP_SUFFIX`` so the startup orphan sweep + can never delete temp files created by other tools. + """ + seen_roots: set[str] = set() + for directory in list(folder_paths.folder_names_and_paths.keys()): + for root in folder_paths.get_folder_paths(directory): + if root in seen_roots or not os.path.isdir(root): + continue + seen_roots.add(root) + try: + for entry in os.scandir(root): + if entry.is_file() and entry.name.endswith(TMP_SUFFIX): + yield entry.path + except OSError: + continue diff --git a/app/model_downloader/security/ssrf.py b/app/model_downloader/security/ssrf.py new file mode 100644 index 000000000..adc217c11 --- /dev/null +++ b/app/model_downloader/security/ssrf.py @@ -0,0 +1,163 @@ +"""SSRF / exfiltration defenses. + +Two cooperating layers: + +1. :class:`ValidatingResolver` is installed on the shared connector. Every + connection — the initial probe and every segment GET, including ones made + after a redirect — resolves its host through this resolver, which rejects + any address that lands on a private / special-use IP range. Because the + resolve and the connect happen together inside the connector, there is no + check-then-connect window for DNS rebinding to exploit. + +2. :func:`check_redirect_hop` re-validates every hop. The host allowlist gates + only the *initial* user-supplied URL (anti-SSRF for arbitrary input); + legitimate downloads from allowlisted origins redirect to presigned CDN + hosts that are deliberately NOT on the allowlist (HF -> + ``cdn-lfs*.huggingface.co``, Civitai -> signed Cloudflare/S3), so hops are + instead screened for scheme, embedded credentials, and — via the resolver + above — private IPs. Credentials are only ever attached when a hop's host + exactly matches a stored credential, so they are dropped on the CDN hop. + Loopback (the "download a local model" feature) is exempt from IP filtering + only for the initial URL: a *redirect* may never target a loopback host or + a blocked IP-literal, which the resolver alone can't enforce (it exempts + loopback literals and never sees IP literals through DNS). +""" + +from __future__ import annotations + +import ipaddress +import socket +from urllib.parse import urlparse + +from aiohttp.abc import AbstractResolver +from aiohttp.resolver import DefaultResolver + +from app.model_downloader.security.allowlist import LOOPBACK_HOSTS + +# Cap the redirect chain length a hop may use. +MAX_REDIRECTS = 5 + + +class SSRFError(Exception): + """A hop failed an SSRF / allowlist check.""" + + +def is_scheme_allowed(scheme: str | None, host: str | None) -> bool: + """True iff ``scheme`` is permitted for ``host`` on a download hop. + + https is always allowed; plain http only for loopback/approved dev hosts. + """ + if not scheme: + return False + scheme = scheme.lower() + if scheme == "https": + return True + if scheme == "http": + return bool(host) and host.lower() in LOOPBACK_HOSTS + return False + + +def is_blocked_ip(ip_str: str) -> bool: + """True for any address we refuse to connect to. + + Covers loopback, link-local (incl. 169.254.169.254 cloud metadata), + RFC1918 private ranges, unique-local (ULA), unspecified (0.0.0.0/::), + multicast and other reserved ranges. + """ + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + return True # unparseable -> refuse + # On CPython before the gh-113171 fix (backported to 3.12.4/3.11.9/ + # 3.10.14/3.9.19) the is_* properties don't see through IPv4-mapped IPv6 + # (e.g. ::ffff:169.254.169.254), so resolve and re-check the embedded IPv4 + # to keep mapped metadata/private addresses from slipping past the filter. + mapped = getattr(ip, "ipv4_mapped", None) + if mapped is not None: + ip = mapped + return ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_multicast + or ip.is_reserved + or ip.is_unspecified + ) + + +class ValidatingResolver(AbstractResolver): + """Delegating resolver that drops blocked IPs from every resolution. + + If a hostname resolves only to blocked addresses, the connection fails + closed with an :class:`OSError`, which aiohttp surfaces as a connection + error to the caller. + """ + + def __init__(self) -> None: + self._inner = DefaultResolver() + + async def resolve(self, host, port=0, family=socket.AF_INET): + infos = await self._inner.resolve(host, port, family) + # localhost/127.0.0.1 are an explicit, opt-in allowlist feature. + if isinstance(host, str) and host.lower() in LOOPBACK_HOSTS: + return infos + safe = [info for info in infos if not is_blocked_ip(info["host"])] + if not safe: + raise OSError( + f"refusing to connect to {host!r}: resolves only to " + f"private/special-use addresses" + ) + return safe + + async def close(self) -> None: + await self._inner.close() + + +def check_redirect_hop(url: str, *, is_initial_url: bool = False) -> str: + """Validate one hop's URL. + + Returns the URL unchanged on success; raises :class:`SSRFError` otherwise. + Requires https for external hosts (http only for loopback/approved dev + hosts) and forbids credentials-in-URL. The host is NOT re-checked against + the allowlist (CDN redirect targets are off-list by design); credential + leakage is prevented by exact host matching at attach time, and the landing + filename's extension is gated separately by the caller. + + Loopback/blocked-IP screening: the connector's resolver filters resolvable + hostnames but exempts literal loopback hosts (``localhost``/``127.0.0.1``/ + ``::1``) and never sees IP literals through DNS. That loopback exemption is + legitimate only for the *initial* user-supplied URL (``is_initial_url``); + on a redirect hop we reject loopback hosts and any blocked IP-literal here, + so a 30x can't steer a server-side GET at loopback/internal services. + """ + try: + parsed = urlparse(url) + except ValueError as e: + raise SSRFError(f"unparseable redirect URL {url!r}: {e}") from e + host = parsed.hostname + if not host: + raise SSRFError(f"redirect URL has no host: {url!r}") + if not is_scheme_allowed(parsed.scheme, host): + raise SSRFError( + f"redirect to disallowed scheme {parsed.scheme!r} for host " + f"{host!r} (https required for external hosts)" + ) + if parsed.username or parsed.password: + raise SSRFError("credentials-in-URL are not allowed") + host_is_loopback = host.lower() in LOOPBACK_HOSTS + if not is_initial_url and host_is_loopback: + raise SSRFError(f"redirect to loopback host {host!r} is not allowed") + # IP-literal targets never go through DNS, so the connector's resolver can't + # screen them — check them directly. The only blocked IP allowed through is + # a loopback literal on the initial URL (handled by the exemption above). + try: + ipaddress.ip_address(host) + except ValueError: + is_ip_literal = False + else: + is_ip_literal = True + if is_ip_literal and is_blocked_ip(host) and not ( + is_initial_url and host_is_loopback + ): + raise SSRFError(f"redirect to blocked internal address {host!r}") + return url diff --git a/app/model_downloader/verify/checksum.py b/app/model_downloader/verify/checksum.py new file mode 100644 index 000000000..335428c81 --- /dev/null +++ b/app/model_downloader/verify/checksum.py @@ -0,0 +1,49 @@ +"""Hub-checksum verification = SHA256. + +Only used to confirm a download matches a *provided* ``expected_sha256``. It +is NOT the dedup key (that is blake3, owned by the assets system). The full +sequential read happens at most once, here, only when a checksum was supplied. +""" + +from __future__ import annotations + +import hashlib +from typing import Callable, Optional + +_CHUNK = 8 * 1024 * 1024 + +InterruptCheck = Callable[[], bool] + + +class ChecksumError(Exception): + """The computed SHA256 did not match the expected value.""" + + +def sha256_file(path: str, interrupt_check: Optional[InterruptCheck] = None) -> Optional[str]: + """Stream the file and return its lowercase hex SHA256. + + Returns ``None`` if interrupted via ``interrupt_check``. + """ + h = hashlib.sha256() + with open(path, "rb") as f: + while True: + if interrupt_check is not None and interrupt_check(): + return None + chunk = f.read(_CHUNK) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + + +def verify_sha256( + path: str, expected: str, interrupt_check: Optional[InterruptCheck] = None +) -> None: + """Raise :class:`ChecksumError` unless the file's SHA256 matches ``expected``.""" + actual = sha256_file(path, interrupt_check) + if actual is None: + return # interrupted; caller will re-verify on resume + if actual.lower() != expected.lower(): + raise ChecksumError( + f"sha256 mismatch: expected {expected.lower()}, got {actual.lower()}" + ) diff --git a/app/model_downloader/verify/dedup.py b/app/model_downloader/verify/dedup.py new file mode 100644 index 000000000..8034a2ba1 --- /dev/null +++ b/app/model_downloader/verify/dedup.py @@ -0,0 +1,53 @@ +"""Dedup + catalog handoff — reuse the assets system. + +We do NOT build a parallel indexer. "Do I already have it?" is answered by +``resolve_existing`` (path) at enqueue time and, where a hash is known, by the +assets blake3 catalog. After a completed download we register the file +through the assets ingest path so it is cataloged and (eventually) hashed by +the existing enrichment worker. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Optional + + +def _register_sync(abs_path: str) -> Optional[str]: + """Register a finished file into the assets catalog. Returns asset hash.""" + try: + from app.assets.services.ingest import register_file_in_place + except Exception as e: # assets package import failure — non-fatal + logging.debug("[model_downloader] assets ingest unavailable: %s", e) + return None + try: + result = register_file_in_place(abs_path, name=os.path.basename(abs_path), tags=[]) + return result.asset.hash if result and result.asset else None + except Exception as e: + # The file is already safely on disk; cataloging is best-effort. + logging.warning( + "[model_downloader] could not register %s into assets catalog: %s", + abs_path, e, + ) + return None + + +async def register_completed(abs_path: str) -> Optional[str]: + """Catalog a completed download via the assets system (off the event loop).""" + return await asyncio.to_thread(_register_sync, abs_path) + + +def _find_by_hash_sync(blake3_hex: str) -> Optional[str]: + try: + from app.assets.services.asset_management import get_asset_by_hash + except Exception: + return None + asset = get_asset_by_hash("blake3:" + blake3_hex) + return asset.hash if asset is not None else None + + +async def find_existing_by_hash(blake3_hex: str) -> Optional[str]: + """Pure DB lookup — never triggers hashing on the hot path.""" + return await asyncio.to_thread(_find_by_hash_sync, blake3_hex) diff --git a/app/model_downloader/verify/structural.py b/app/model_downloader/verify/structural.py new file mode 100644 index 000000000..2d8fca924 --- /dev/null +++ b/app/model_downloader/verify/structural.py @@ -0,0 +1,86 @@ +"""Cheap structural validation, no full read. + +For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries +the tensor table and the byte length of the data region. We assert +``file_size == 8 + header_len + data_region_len``. This detects truncation +and most corruption for free, before any crypto hashing. Other extensions +have no cheap structural check and pass through. +""" + +from __future__ import annotations + +import json +import os +import struct +from typing import Optional + +_SAFETENSORS_EXTS = (".safetensors", ".sft") +# A sane upper bound so a corrupt header length can't make us read gigabytes. +_MAX_HEADER_BYTES = 100 * 1024 * 1024 + + +class StructuralError(Exception): + """The file failed its structural integrity check.""" + + +def validate(path: str, name_hint: Optional[str] = None) -> None: + """Validate the file at ``path``. Raises :class:`StructuralError` on failure. + + The file format is detected from ``name_hint`` when provided, otherwise from + ``path``. Callers that download into a temp file with an opaque suffix (e.g. + ``*.comfy-download.part``) must pass the final destination name as + ``name_hint`` so the format check is not silently skipped. + """ + lower = (name_hint or path).lower() + if lower.endswith(_SAFETENSORS_EXTS): + _validate_safetensors(path) + # No structural check for other formats; the size + (optional) checksum + # gates in the engine cover those. + + +def _validate_safetensors(path: str) -> None: + file_size = os.path.getsize(path) + if file_size < 8: + raise StructuralError(f"file too small to be safetensors ({file_size} bytes)") + with open(path, "rb") as f: + header_len = struct.unpack(" _MAX_HEADER_BYTES: + raise StructuralError(f"implausible safetensors header length {header_len}") + if 8 + header_len > file_size: + raise StructuralError("safetensors header extends past end of file") + try: + header = json.loads(f.read(header_len).decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError) as e: + raise StructuralError(f"safetensors header is not valid JSON: {e}") from e + + if not isinstance(header, dict): + raise StructuralError("safetensors header is not a JSON object") + + data_len = 0 + for name, entry in header.items(): + if name == "__metadata__": + continue + if not isinstance(entry, dict) or "data_offsets" not in entry: + raise StructuralError(f"tensor {name!r} missing data_offsets") + offsets = entry["data_offsets"] + if not (isinstance(offsets, list) and len(offsets) == 2): + raise StructuralError(f"tensor {name!r} has malformed data_offsets") + begin, end = offsets + # bool is an int subclass; reject it explicitly to avoid True/False offsets. + if ( + not isinstance(begin, int) + or not isinstance(end, int) + or isinstance(begin, bool) + or isinstance(end, bool) + or begin < 0 + or end < begin + ): + raise StructuralError(f"tensor {name!r} has malformed data_offsets") + data_len = max(data_len, end) + + expected = 8 + header_len + data_len + if file_size != expected: + raise StructuralError( + f"size mismatch: file is {file_size} bytes, header implies {expected} " + f"(8 + {header_len} header + {data_len} data)" + ) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 4bef096fb..ce5a27c81 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -33,6 +33,28 @@ class EnumAction(argparse.Action): setattr(namespace, self.dest, value) +def _positive_int(value: str) -> int: + """argparse type that rejects zero and negative integers.""" + try: + ivalue = int(value) + except ValueError: + raise argparse.ArgumentTypeError(f"{value!r} is not an integer") + if ivalue <= 0: + raise argparse.ArgumentTypeError(f"{value!r} must be a positive integer (> 0)") + return ivalue + + +def _non_negative_int(value: str) -> int: + """argparse type that rejects negatives but allows zero (a disable sentinel).""" + try: + ivalue = int(value) + except ValueError: + raise argparse.ArgumentTypeError(f"{value!r} is not an integer") + if ivalue < 0: + raise argparse.ArgumentTypeError(f"{value!r} must be a non-negative integer (>= 0)") + return ivalue + + parser = argparse.ArgumentParser() parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)") @@ -244,6 +266,15 @@ parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button") parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.") +# ----- Model download manager (PRD: docs/prd-download-manager.md) ----- +parser.add_argument("--download-segments", type=_positive_int, default=8, metavar="N", help="Number of parallel HTTP range segments per file for the model download manager (default: 8).") +parser.add_argument("--download-max-active", type=_positive_int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).") +parser.add_argument("--download-max-connections-per-host", type=_positive_int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).") +parser.add_argument("--download-chunk-size", type=_positive_int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).") +parser.add_argument("--download-max-bytes", type=_non_negative_int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).") +parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.") +parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0f30608a9..4ce977c20 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -104,6 +104,7 @@ _CORE_FEATURE_FLAGS: dict[str, Any] = { "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, "assets": args.enable_assets, + "server_side_model_downloads": True, } # CLI-provided flags cannot overwrite core flags diff --git a/openapi.yaml b/openapi.yaml index c6a8621cc..67de1704c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -230,6 +230,93 @@ components: - base_version - workflow_json type: object + DownloadEnqueueRequest: + description: Request body for enqueuing a server-side model download. + properties: + allow_any_extension: + default: false + description: Permit a non-model file extension (default only allows known model extensions). + type: boolean + credential_id: + description: Explicit per-host credential to use; otherwise auto-resolved by host. Still subject to the per-hop host match. + nullable: true + type: string + expected_sha256: + description: Optional hub-provided SHA256 to verify the completed file against (fail-closed). + nullable: true + type: string + model_id: + description: Destination as "/", resolving to a registered model folder (e.g. "loras/my_lora.safetensors"). + type: string + priority: + default: 0 + description: Scheduling priority; higher is admitted first. + type: integer + url: + description: Source URL; must be on the allowlist (host + scheme + extension). + type: string + required: + - url + - model_id + type: object + DownloadStatus: + description: Current state and live progress of a single download. + properties: + bytes_done: + type: integer + created_at: + type: integer + download_id: + format: uuid + type: string + error: + nullable: true + type: string + eta_seconds: + nullable: true + type: number + model_id: + type: string + priority: + type: integer + progress: + description: Fraction in [0,1]; null until total size is known. + nullable: true + type: number + segments: + description: Per-segment progress (segmented downloads only). + items: + properties: + bytes_done: + type: integer + idx: + type: integer + length: + type: integer + type: object + nullable: true + type: array + speed_bps: + nullable: true + type: number + status: + enum: + - queued + - active + - paused + - verifying + - completed + - failed + - cancelled + type: string + total_bytes: + nullable: true + type: integer + updated_at: + type: integer + url: + type: string + type: object ErrorResponse: description: Standard error response with a machine-readable code and human-readable message. properties: @@ -511,6 +598,78 @@ components: required: - history type: object + HostCredentialUpsert: + description: Request body for upserting a per-host credential. The secret is write-only. + properties: + auth_scheme: + default: bearer + description: How the secret is attached to requests. + enum: + - bearer + - header + - query + type: string + enabled: + default: true + type: boolean + header_name: + description: Header name when auth_scheme=header (defaults to Authorization). + nullable: true + type: string + host: + description: Normalized hostname the key applies to (e.g. "civitai.com"). + type: string + label: + description: User-friendly name for display. + nullable: true + type: string + match_subdomains: + default: false + description: Also match label-boundary subdomains of host (off by default; unsafe for hub CDNs). + type: boolean + query_param: + description: Query parameter name when auth_scheme=query. + nullable: true + type: string + secret: + description: The API key. Write-only — never returned by any endpoint. + type: string + required: + - host + - secret + type: object + HostCredentialView: + description: Masked, API-safe view of a stored credential. Never includes the secret. + properties: + auth_scheme: + type: string + created_at: + type: integer + enabled: + type: boolean + header_name: + nullable: true + type: string + host: + type: string + id: + format: uuid + type: string + label: + nullable: true + type: string + match_subdomains: + type: boolean + query_param: + nullable: true + type: string + secret_last4: + description: Last 4 characters of the secret, for masked display only. + nullable: true + type: string + updated_at: + type: integer + type: object JobCancelResponse: description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops. properties: @@ -2350,6 +2509,391 @@ paths: summary: Get tag histogram for filtered assets tags: - file + /api/download: + get: + description: List all known downloads (queued, active, paused, and terminal) with live progress. + operationId: listDownloads + responses: + "200": + content: + application/json: + schema: + properties: + downloads: + items: + $ref: '#/components/schemas/DownloadStatus' + type: array + type: object + description: List of downloads + summary: List downloads + tags: + - download + /api/download/availability: + post: + description: | + Bulk per-id availability for a set of model_ids declared in a workflow. + Returns whether each model is available on disk, currently downloading + (with progress), or missing, plus whether its URL is on the allowlist. + operationId: getModelsAvailability + requestBody: + content: + application/json: + schema: + properties: + models: + additionalProperties: + type: string + description: Map of "/" model_id to its declared source URL. + type: object + type: object + responses: + "200": + content: + application/json: + schema: + properties: + models: + additionalProperties: true + type: object + type: object + description: Per-id availability map + summary: Bulk model availability + status + tags: + - download + /api/download/clear: + post: + description: | + Delete all terminal downloads (completed, failed, cancelled) from history + in one transaction, so the cleared history persists across reloads. Live + downloads (queued, active, paused, verifying) are skipped. Finished model + files on disk are never removed; only leftover .part temp files are cleaned up. + operationId: clearDownloads + responses: + "200": + content: + application/json: + schema: + properties: + deleted: + description: Number of history rows removed. + type: integer + type: object + description: History cleared + summary: Clear terminal downloads from history + tags: + - download + /api/download/credentials: + get: + description: List stored per-host credentials. Secrets are never returned; only masked metadata (last 4 chars, scheme, label). + operationId: listDownloadCredentials + responses: + "200": + content: + application/json: + schema: + properties: + credentials: + items: + $ref: '#/components/schemas/HostCredentialView' + type: array + type: object + description: Masked credential list + summary: List host credentials (masked) + tags: + - download + post: + description: | + Upsert (by host) a per-host API key used to authenticate downloads. + The secret is write-only: it is stored once here and never returned by any endpoint. + operationId: upsertDownloadCredential + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/HostCredentialUpsert' + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/HostCredentialView' + description: Credential stored (masked view returned) + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid credential + summary: Upsert a host credential + tags: + - download + /api/download/credentials/{id}: + delete: + description: Delete a stored host credential. + operationId: deleteDownloadCredential + parameters: + - in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + properties: + deleted: + type: boolean + type: object + description: Deleted + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: No such credential + summary: Delete a host credential + tags: + - download + get: + description: Get a single host credential (masked; never includes the secret). + operationId: getDownloadCredential + parameters: + - in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/HostCredentialView' + description: Masked credential + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: No such credential + summary: Get a host credential (masked) + tags: + - download + /api/download/enqueue: + post: + description: | + Enqueue a server-side model download. The URL must be on the allowlist + (host + scheme + extension) and the model_id must be "/" + resolving to a registered model folder. Returns immediately; track progress + via GET /api/download/{id} or the "download_progress" websocket event. + operationId: enqueueDownload + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DownloadEnqueueRequest' + responses: + "202": + content: + application/json: + schema: + properties: + accepted: + type: boolean + download_id: + format: uuid + type: string + type: object + description: Download accepted and queued + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request (bad URL, model_id, or not allowlisted) + "409": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Already on disk or already downloading + summary: Enqueue a model download + tags: + - download + /api/download/{id}: + delete: + description: | + Delete a single terminal download from history so it stays gone across + reloads. Refuses (409) to delete a live download (queued, active, paused, + verifying) — cancel it first. The finished model file on disk is never + removed; only a leftover .part temp file is cleaned up. + operationId: deleteDownload + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + properties: + deleted: + type: boolean + type: object + description: Deleted + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: No such download + "409": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Download is still in progress + summary: Delete a download from history + tags: + - download + get: + description: Get the current status + progress of a single download. + operationId: getDownloadStatus + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/DownloadStatus' + description: Download status + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: No such download + summary: Get download status + tags: + - download + /api/download/{id}/cancel: + post: + description: Cancel a download. The partial file is removed. + operationId: cancelDownload + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Cancelled + summary: Cancel a download + tags: + - download + /api/download/{id}/pause: + post: + description: Pause a download. The partial file and per-segment offsets are retained for resume. + operationId: pauseDownload + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Paused + summary: Pause a download + tags: + - download + /api/download/{id}/priority: + post: + description: Set a download's scheduling priority. Higher priority is admitted first when a slot frees. + operationId: setDownloadPriority + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + requestBody: + content: + application/json: + schema: + properties: + priority: + type: integer + required: + - priority + type: object + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Priority updated + summary: Set download priority + tags: + - download + /api/download/{id}/resume: + post: + description: Resume a paused (or failed) download from its persisted offsets. + operationId: resumeDownload + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Resumed + summary: Resume a download + tags: + - download /api/embeddings: get: description: Returns the list of text-encoder embeddings available on disk. @@ -5103,3 +5647,5 @@ tags: name: queue - description: Job lifecycle queries name: job + - description: Model download management + name: download diff --git a/server.py b/server.py index 361850f38..84e2c7575 100644 --- a/server.py +++ b/server.py @@ -45,6 +45,8 @@ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal from app.assets.seeder import asset_seeder from app.assets.api.routes import register_assets_routes +from app.model_downloader.api.routes import register_routes as register_model_downloader_routes +from app.model_downloader.manager import DOWNLOAD_MANAGER from app.assets.services.ingest import register_file_in_place from app.assets.services.asset_management import resolve_hash_to_path @@ -256,6 +258,7 @@ class PromptServer(): else: register_assets_routes(self.app) asset_seeder.disable() + register_model_downloader_routes(self.app) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -1182,6 +1185,29 @@ class PromptServer(): async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout self.client_session = aiohttp.ClientSession(timeout=timeout) + await self._setup_model_downloader() + + async def _setup_model_downloader(self): + """Start the download manager: push progress over the websocket and + resume any downloads interrupted by a previous run.""" + def _notify(download_id: str) -> None: + try: + view = DOWNLOAD_MANAGER.status_sync(download_id) + if view is not None: + # Drop the url field before broadcasting: the redacted URL + # (scheme + host + path) should not leak to every connected + # websocket client. download_id / model_id are sufficient to + # correlate progress on the frontend. + broadcast = {k: v for k, v in view.items() if k != "url"} + self.send_sync("download_progress", broadcast) + except Exception: + logging.debug("download progress notify failed", exc_info=True) + + DOWNLOAD_MANAGER.set_notify(_notify) + try: + await DOWNLOAD_MANAGER.start() + except Exception as e: + logging.warning("Failed to start model download manager: %s", e) def add_routes(self): self.user_manager.add_routes(self.routes) diff --git a/tests-unit/model_downloader_test/conftest.py b/tests-unit/model_downloader_test/conftest.py new file mode 100644 index 000000000..33165294b --- /dev/null +++ b/tests-unit/model_downloader_test/conftest.py @@ -0,0 +1,90 @@ +"""Shared fixtures for the model download manager tests. + +These run in-process (no ComfyUI subprocess): a file-backed SQLite DB is +initialized once, a temp model folder is registered with ``folder_paths``, and +the shared aiohttp session is reset between tests so each async test gets a +session bound to its own event loop. +""" + +from __future__ import annotations + +import asyncio +import os +import tempfile + +import pytest + + +def _drain_scheduler_tasks(scheduler) -> None: + """Cancel and await live scheduler tasks so none outlive the test. + + Uses the actual task handles rather than only clearing ``_tasks``: each + per-test event loop is created by ``asyncio.run``, so a task left behind by + a crashed/aborted test would otherwise keep its coroutine alive. We cancel + every live task and, when its loop is still usable, run it to completion to + let the cancellation propagate before dropping the reference. + """ + for task in list(scheduler._tasks.values()): + if task is None: + continue + loop = task.get_loop() + if task.done() or loop.is_closed(): + continue + task.cancel() + if not loop.is_running(): + try: + loop.run_until_complete(asyncio.gather(task, return_exceptions=True)) + except Exception: + pass + scheduler._tasks.clear() + + +@pytest.fixture(scope="session", autouse=True) +def _init_db(): + import app.database.db as db + from comfy.cli_args import args + + fd, db_path = tempfile.mkstemp(suffix="-dlmgr-test.sqlite3") + os.close(fd) + args.database_url = f"sqlite:///{db_path}" + db.init_db() + yield + try: + os.remove(db_path) + except OSError: + pass + + +@pytest.fixture(autouse=True) +def _reset_runtime(): + """Reset module singletons that hold event-loop-bound or cross-test state.""" + import app.model_downloader.net.session as ns + from app.model_downloader.scheduler import SCHEDULER + + ns._session = None + _drain_scheduler_tasks(SCHEDULER) + SCHEDULER._jobs.clear() + SCHEDULER._backoff_until.clear() + SCHEDULER._started = False + yield + _drain_scheduler_tasks(SCHEDULER) + ns._session = None + + +@pytest.fixture +def model_root(tmp_path): + """Register a temp 'loras' model folder and return its absolute path.""" + import folder_paths + + root = tmp_path / "loras" + root.mkdir(parents=True, exist_ok=True) + saved = folder_paths.folder_names_and_paths.get("loras") + folder_paths.folder_names_and_paths["loras"] = ( + [str(root)], + {".safetensors", ".sft", ".ckpt", ".pt", ".pth"}, + ) + yield str(root) + if saved is not None: + folder_paths.folder_names_and_paths["loras"] = saved + else: + folder_paths.folder_names_and_paths.pop("loras", None) diff --git a/tests-unit/model_downloader_test/test_credentials.py b/tests-unit/model_downloader_test/test_credentials.py new file mode 100644 index 000000000..a630ad751 --- /dev/null +++ b/tests-unit/model_downloader_test/test_credentials.py @@ -0,0 +1,166 @@ +"""Unit tests for the credential store and the per-hop credential resolver. + +Covers the critical rule: a secret is only ever attached when the current +hop's host matches a stored credential, and never over a non-https hop. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from app.model_downloader.credentials import resolver +from app.model_downloader.credentials.store import ( + CREDENTIAL_STORE, + CredentialValidationError, + normalize_host, +) +from app.model_downloader.database.models import HostCredential + + +# ----- pure host normalization + matching ----- + + +@pytest.mark.parametrize( + "raw,expected", + [ + ("Civitai.com", "civitai.com"), + ("HuggingFace.co:443", "huggingface.co"), + (" Example.COM ", "example.com"), + ], +) +def test_normalize_host(raw, expected): + assert normalize_host(raw) == expected + + +def _cred(**kw) -> HostCredential: + base = dict( + id="x", host="civitai.com", match_subdomains=False, auth_scheme="bearer", + secret="SECRET", enabled=True, + ) + base.update(kw) + return HostCredential(**base) + + +def test_matches_exact_only_by_default(): + c = _cred(host="civitai.com") + assert resolver._matches(c, "civitai.com") is True + assert resolver._matches(c, "api.civitai.com") is False + assert resolver._matches(c, "evil-civitai.com") is False + + +def test_matches_subdomain_label_boundary(): + c = _cred(host="example.com", match_subdomains=True) + assert resolver._matches(c, "api.example.com") is True + assert resolver._matches(c, "example.com") is True + # not a label boundary -> no match + assert resolver._matches(c, "evil-example.com") is False + + +def test_build_auth_shapes(): + assert resolver._build_auth(_cred(auth_scheme="bearer")).headers == { + "Authorization": "Bearer SECRET" + } + assert resolver._build_auth( + _cred(auth_scheme="header", header_name="X-Api-Key") + ).headers == {"X-Api-Key": "SECRET"} + q = resolver._build_auth(_cred(auth_scheme="query", query_param="token")) + assert q.query == {"token": "SECRET"} + assert q.apply_to_url("https://civitai.com/x") == "https://civitai.com/x?token=SECRET" + + +# ----- DB-backed store + resolver ----- + + +def test_store_upsert_is_write_only_and_masked(): + async def _run(): + view = await CREDENTIAL_STORE.upsert("civitai.com", "abcd1234", label="my key") + # The view never carries the secret, only the last 4. + assert not hasattr(view, "secret") + assert view.secret_last4 == "1234" + assert view.host == "civitai.com" + listed = await CREDENTIAL_STORE.list() + assert any(v.host == "civitai.com" for v in listed) + await CREDENTIAL_STORE.delete(view.id) + asyncio.run(_run()) + + +def test_query_scheme_requires_param(): + async def _run(): + with pytest.raises(CredentialValidationError): + await CREDENTIAL_STORE.upsert("civitai.com", "k", auth_scheme="query") + asyncio.run(_run()) + + +def test_resolver_never_crosses_host_boundary(): + async def _run(): + view = await CREDENTIAL_STORE.upsert("huggingface.co", "hf_secret_key") + try: + # matching host over https -> attached + auth = await resolver.resolve_auth_for_hop("huggingface.co", "https") + assert auth is not None + assert auth.headers["Authorization"] == "Bearer hf_secret_key" + # CDN redirect host -> dropped + assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None + # non-https hop -> never attached + assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None + finally: + await CREDENTIAL_STORE.delete(view.id) + asyncio.run(_run()) + + +# ----- env-based HF token fallback ----- + + +def test_env_token_fallback_attaches_when_no_db_credential(monkeypatch): + monkeypatch.setenv("HF_TOKEN", "env_hf_token") + + async def _run(): + # exact host over https -> env token attached + auth = await resolver.resolve_auth_for_hop("huggingface.co", "https") + assert auth is not None + assert auth.headers["Authorization"] == "Bearer env_hf_token" + # non-https hop -> never attached + assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None + # CDN redirect host -> dropped (exact-host only) + assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None + asyncio.run(_run()) + + +def test_env_token_secondary_var_is_honored(monkeypatch): + monkeypatch.delenv("HF_TOKEN", raising=False) + monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "env_hub_token") + + async def _run(): + auth = await resolver.resolve_auth_for_hop("huggingface.co", "https") + assert auth is not None + assert auth.headers["Authorization"] == "Bearer env_hub_token" + asyncio.run(_run()) + + +def test_db_credential_takes_precedence_over_env(monkeypatch): + monkeypatch.setenv("HF_TOKEN", "env_hf_token") + + async def _run(): + view = await CREDENTIAL_STORE.upsert("huggingface.co", "db_secret_key") + try: + auth = await resolver.resolve_auth_for_hop("huggingface.co", "https") + assert auth is not None + assert auth.headers["Authorization"] == "Bearer db_secret_key" + finally: + await CREDENTIAL_STORE.delete(view.id) + asyncio.run(_run()) + + +def test_env_token_does_not_leak_into_explicit_path(monkeypatch): + monkeypatch.setenv("HF_TOKEN", "env_hf_token") + + async def _run(): + # An explicit credential id that doesn't resolve must stay None; the env + # fallback only applies to the auto-resolve branch. + auth = await resolver.resolve_auth_for_hop( + "huggingface.co", "https", explicit_credential_id="does-not-exist" + ) + assert auth is None + asyncio.run(_run()) diff --git a/tests-unit/model_downloader_test/test_delete.py b/tests-unit/model_downloader_test/test_delete.py new file mode 100644 index 000000000..d9e675075 --- /dev/null +++ b/tests-unit/model_downloader_test/test_delete.py @@ -0,0 +1,136 @@ +"""Unit tests for ``DownloadManager.delete`` and ``DownloadManager.clear``. + +Deleting a terminal row must remove it from history for good (so it does not +reappear on the next ``list``), leave live rows untouched, and clean up any +leftover ``.part`` temp file without touching the finished model file. + +``clear()`` is the bulk variant: it removes all terminal rows atomically, skips +live ones, and returns the count of rows deleted. + +Async methods are driven via ``asyncio.run`` so no pytest-asyncio plugin is +required. +""" + +from __future__ import annotations + +import asyncio +import os + +import pytest + +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError + + +def _insert(download_id: str, status: str, *, temp_path: str = "/tmp/none.part") -> None: + queries.insert_download( + { + "id": download_id, + "url": "https://huggingface.co/org/model.safetensors", + "model_id": "loras/model.safetensors", + "dest_path": "/tmp/model.safetensors", + "temp_path": temp_path, + "status": status, + "priority": 0, + } + ) + + +def test_delete_removes_terminal_row_from_history(): + _insert("done", DownloadStatus.COMPLETED) + + asyncio.run(DOWNLOAD_MANAGER.delete("done")) + + assert queries.get_download("done") is None + + +def test_delete_refuses_live_row(): + _insert("live", DownloadStatus.QUEUED) + + with pytest.raises(DownloadError) as excinfo: + asyncio.run(DOWNLOAD_MANAGER.delete("live")) + + assert excinfo.value.code == "DOWNLOAD_ACTIVE" + assert queries.get_download("live") is not None + + +def test_delete_missing_row_raises_not_found(): + with pytest.raises(DownloadError) as excinfo: + asyncio.run(DOWNLOAD_MANAGER.delete("nope")) + + assert excinfo.value.code == "NOT_FOUND" + + +def test_delete_removes_leftover_temp_file(tmp_path): + partial = tmp_path / "model.safetensors.part" + partial.write_bytes(b"partial") + _insert("failed", DownloadStatus.FAILED, temp_path=str(partial)) + + asyncio.run(DOWNLOAD_MANAGER.delete("failed")) + + assert not os.path.exists(partial) + assert queries.get_download("failed") is None + + +# ----- clear ----- + + +def test_clear_removes_all_terminal_rows(): + _insert("c-done", DownloadStatus.COMPLETED) + _insert("c-fail", DownloadStatus.FAILED) + _insert("c-canc", DownloadStatus.CANCELLED) + + deleted = asyncio.run(DOWNLOAD_MANAGER.clear()) + + assert deleted == 3 + assert queries.get_download("c-done") is None + assert queries.get_download("c-fail") is None + assert queries.get_download("c-canc") is None + + +def test_clear_skips_live_rows(): + _insert("cl-queued", DownloadStatus.QUEUED) + _insert("cl-paused", DownloadStatus.PAUSED) + _insert("cl-done", DownloadStatus.COMPLETED) + + deleted = asyncio.run(DOWNLOAD_MANAGER.clear()) + + assert deleted == 1 + assert queries.get_download("cl-queued") is not None + assert queries.get_download("cl-paused") is not None + assert queries.get_download("cl-done") is None + + +def test_clear_returns_zero_when_nothing_to_delete(): + _insert("cl-only-live", DownloadStatus.QUEUED) + + deleted = asyncio.run(DOWNLOAD_MANAGER.clear()) + + assert deleted == 0 + assert queries.get_download("cl-only-live") is not None + + +def test_clear_removes_leftover_temp_files(tmp_path): + partial = tmp_path / "clear_partial.part" + partial.write_bytes(b"partial data") + finished = tmp_path / "finished.safetensors" + finished.write_bytes(b"real model weights") + + _insert("cl-part", DownloadStatus.FAILED, temp_path=str(partial)) + # The finished file is not the temp_path; temp_path for a completed download + # no longer exists (already renamed), so use a non-existent path here to + # verify clear() tolerates a missing temp file without raising. + _insert("cl-comp", DownloadStatus.COMPLETED, temp_path=str(tmp_path / "gone.part")) + + asyncio.run(DOWNLOAD_MANAGER.clear()) + + # Leftover .part from the failed download is cleaned up. + assert not partial.exists() + # Finished model file is never touched. + assert finished.exists() + + +def test_clear_empty_db_returns_zero(): + deleted = asyncio.run(DOWNLOAD_MANAGER.clear()) + assert deleted == 0 diff --git a/tests-unit/model_downloader_test/test_engine_integration.py b/tests-unit/model_downloader_test/test_engine_integration.py new file mode 100644 index 000000000..435c7f4c7 --- /dev/null +++ b/tests-unit/model_downloader_test/test_engine_integration.py @@ -0,0 +1,637 @@ +"""Integration tests for the download engine against a local aiohttp server. + +Covers single-stream and segmented transfers, deterministic resume from a +partial file, and cancel rollback. Async tests are driven via ``asyncio.run`` +so no pytest-asyncio plugin is required. +""" + +from __future__ import annotations + +import asyncio +import json +import os +import struct +import uuid + +import pytest +from aiohttp import web + +from comfy.cli_args import args +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.engine.job import DownloadJob, JobSpec +from app.model_downloader.net.session import close_session +from app.model_downloader.security import paths + +PAYLOAD_ETAG = '"v1"' + + +def _payload(n: int) -> bytes: + return bytes((i * 37 + 11) % 256 for i in range(n)) + + +def _safetensors_payload(total: int) -> bytes: + """A structurally valid ``.safetensors`` blob of exactly ``total`` bytes. + + Success-path tests download to ``.safetensors`` destinations, which the + engine now structurally validates before the atomic rename, so their + payloads must parse as real safetensors (header length + JSON header + + data region whose size matches the declared ``data_offsets``). + """ + def _header(data_len: int) -> bytes: + return json.dumps( + {"w": {"dtype": "U8", "shape": [data_len], "data_offsets": [0, data_len]}} + ).encode("utf-8") + + # The header's byte length depends on the digit count of ``data_len``, so + # iterate until ``total == 8 + len(header) + data_len`` is self-consistent. + data_len = total - 8 - len(_header(total)) + for _ in range(8): + header = _header(data_len) + new_data_len = total - 8 - len(header) + if new_data_len == data_len: + break + data_len = new_data_len + assert data_len >= 0, "total too small for a safetensors payload" + header = _header(data_len) + body = bytes((i * 37 + 11) % 256 for i in range(data_len)) + return struct.pack(" web.Response: + rng = request.headers.get("Range") + if rng: + spec = rng.split("=", 1)[1] + s, _, e = spec.partition("-") + start = int(s) + end = int(e) if e else len(payload) - 1 + chunk = payload[start : end + 1] + return web.Response( + status=206, + body=chunk, + headers={ + "Content-Range": f"bytes {start}-{end}/{len(payload)}", + "Accept-Ranges": "bytes", + "ETag": PAYLOAD_ETAG, + }, + ) + return web.Response( + status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG} + ) + + return handler + + +def _content_disposition_handler(payload: bytes, filename: str): + """A range-capable server that only reveals its filename via a header. + + Models a Civitai-style ``/api/download/...`` endpoint: the URL path has no + extension, and the real filename (hence extension) lives in the response + ``Content-Disposition`` header. + """ + + async def handler(request: web.Request) -> web.Response: + headers = { + "Accept-Ranges": "bytes", + "ETag": PAYLOAD_ETAG, + "Content-Disposition": f'attachment; filename="{filename}"', + } + rng = request.headers.get("Range") + if rng: + spec = rng.split("=", 1)[1] + s, _, e = spec.partition("-") + start = int(s) + end = int(e) if e else len(payload) - 1 + chunk = payload[start : end + 1] + return web.Response( + status=206, + body=chunk, + headers={**headers, "Content-Range": f"bytes {start}-{end}/{len(payload)}"}, + ) + return web.Response(status=200, body=payload, headers=headers) + + return handler + + +def _noranges_handler(payload: bytes): + async def handler(request: web.Request) -> web.Response: + # Always full body, never advertises Accept-Ranges -> single-stream. + return web.Response(status=200, body=payload) + + return handler + + +def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01): + async def handler(request: web.Request) -> web.StreamResponse: + resp = web.StreamResponse( + status=200, headers={"Content-Length": str(len(payload))} + ) + await resp.prepare(request) + for i in range(0, len(payload), chunk): + await resp.write(payload[i : i + chunk]) + await asyncio.sleep(delay) + await resp.write_eof() + return resp + + return handler + + +def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024): + """A non-conforming 206 server that returns MORE than the requested range.""" + + async def handler(request: web.Request) -> web.Response: + rng = request.headers.get("Range") + if rng: + spec = rng.split("=", 1)[1] + s, _, e = spec.partition("-") + start = int(s) + end = int(e) if e else len(payload) - 1 + # Maliciously overrun: append extra bytes past the requested end. + body = payload[start : end + 1] + bytes(extra) + return web.Response( + status=206, + body=body, + headers={ + "Content-Range": f"bytes {start}-{end}/{len(payload)}", + "Accept-Ranges": "bytes", + "ETag": PAYLOAD_ETAG, + }, + ) + return web.Response( + status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG} + ) + + return handler + + +def _short_range_handler(payload: bytes, drop: int = 64 * 1024): + """A 206 server that returns fewer bytes than requested for later segments. + + Simulates a server cleanly closing a range connection early. The response + is internally consistent (Content-Length matches the short body), so the + client sees no error and the segment just ends short, leaving a zero-filled + hole in the preallocated file. + """ + + async def handler(request: web.Request) -> web.Response: + rng = request.headers.get("Range") + if rng: + spec = rng.split("=", 1)[1] + s, _, e = spec.partition("-") + start = int(s) + end = int(e) if e else len(payload) - 1 + chunk = payload[start : end + 1] + if start > 0 and len(chunk) > drop: + chunk = chunk[:-drop] # truncate a non-first segment + return web.Response( + status=206, + body=chunk, + headers={ + "Content-Range": f"bytes {start}-{end}/{len(payload)}", + "Accept-Ranges": "bytes", + "ETag": PAYLOAD_ETAG, + }, + ) + return web.Response( + status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG} + ) + + return handler + + +def _unbounded_handler(total: int, chunk: int = 16384): + """A 200 stream with no Content-Length / Accept-Ranges (unknown length).""" + + async def handler(request: web.Request) -> web.StreamResponse: + resp = web.StreamResponse(status=200) + await resp.prepare(request) + sent = 0 + while sent < total: + await resp.write(bytes(min(chunk, total - sent))) + sent += chunk + await resp.write_eof() + return resp + + return handler + + +async def _serve(handler): + app = web.Application() + app.router.add_route("*", "/{name:.*}", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] + return runner, port + + +def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]: + final_path, temp_path = paths.resolve_destination(model_id) + download_id = str(uuid.uuid4()) + queries.insert_download( + { + "id": download_id, + "url": url, + "model_id": model_id, + "dest_path": final_path, + "temp_path": temp_path, + "status": status, + } + ) + return download_id, final_path, temp_path + + +# ----- single-stream ----- + + +def test_single_stream_download(model_root): + payload = _safetensors_payload(300_000) + + async def _run(): + await close_session() + runner, port = await _serve(_noranges_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, _temp = _insert("loras/single.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/single.safetensors", + dest_path=final_path, temp_path=_temp, + )) + status = await job.run() + assert status == DownloadStatus.COMPLETED, queries.get_download(did).error + assert os.path.exists(final_path) + assert open(final_path, "rb").read() == payload + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- segmented ----- + + +def test_segmented_download(model_root): + payload = _safetensors_payload(4 * 1024 * 1024) # 4 MiB -> multiple segments + + async def _run(): + await close_session() + runner, port = await _serve(_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/seg.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/seg.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.COMPLETED, queries.get_download(did).error + assert open(final_path, "rb").read() == payload + # More than one segment row was planned. + assert len(queries.list_segments(did)) > 1 + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- deterministic resume from a partial file ----- + + +def test_resume_from_partial(model_root): + payload = _safetensors_payload(512 * 1024) # < 1 MiB -> single segment + + async def _run(): + await close_session() + runner, port = await _serve(_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/resume.safetensors", url) + # Simulate a prior partial: first 200 KiB already written, offset persisted. + prefix = 200 * 1024 + os.makedirs(os.path.dirname(temp), exist_ok=True) + with open(temp, "wb") as f: + f.write(payload[:prefix]) + queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG) + + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/resume.safetensors", + dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG, + )) + status = await job.run() + assert status == DownloadStatus.COMPLETED, queries.get_download(did).error + assert open(final_path, "rb").read() == payload + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- cancel rollback ----- + + +def test_cancel_rollback(model_root, monkeypatch): + monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False) + payload = _payload(1024 * 1024) + + async def _run(): + await close_session() + runner, port = await _serve(_slow_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/cancel.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/cancel.safetensors", + dest_path=final_path, temp_path=temp, + )) + task = asyncio.ensure_future(job.run()) + # Wait until some bytes have been written, then cancel. + for _ in range(200): + await asyncio.sleep(0.01) + if job.state.bytes_done > 0: + break + job.request_cancel() + status = await task + assert status == DownloadStatus.CANCELLED + assert not os.path.exists(temp) + assert not os.path.exists(final_path) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- size-bound enforcement (malicious / non-conforming hosts) ----- + + +def test_segment_overflow_aborts(model_root): + """A 206 returning more than the requested range must not overrun.""" + payload = _payload(4 * 1024 * 1024) # large enough to segment + + async def _run(): + await close_session() + runner, port = await _serve(_overflow_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/overflow.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/overflow.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_short_segment_fails_closed(model_root): + """A segment that ends short must fail, not be accepted as complete. + + The file is preallocated to total_bytes, so the on-disk size still equals + total even with a zero-filled hole; completeness must be judged per-segment. + """ + payload = _safetensors_payload(4 * 1024 * 1024) # large enough to segment + + async def _run(): + await close_session() + runner, port = await _serve(_short_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/short.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/short.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED, queries.get_download(did).error + assert "incomplete" in (queries.get_download(did).error or "") + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_structural_validation_rejects_corrupt(model_root): + """A correctly sized but structurally invalid file fails closed (not retried). + + Regression for the dead structural gate: validation must key off the + destination extension, not the ``.part`` temp suffix. + """ + payload = _payload(300_000) # right size, but not a valid safetensors blob + + async def _run(): + await close_session() + runner, port = await _serve(_noranges_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/corrupt.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/corrupt.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED, queries.get_download(did).error + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + # Failed closed at first attempt, not re-queued as retryable. + assert queries.get_download(did).attempts == 0 + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_rejects_oversized_known_download(model_root, monkeypatch): + """A file whose advertised size exceeds the cap is rejected at probe.""" + monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False) + payload = _payload(300_000) + + async def _run(): + await close_session() + runner, port = await _serve(_noranges_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/toobig.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/toobig.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED + assert not os.path.exists(final_path) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch): + """An unbounded unknown-length stream is capped by --download-max-bytes.""" + monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False) + monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False) + + async def _run(): + await close_session() + runner, port = await _serve(_unbounded_handler(2 * 1024 * 1024)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/unbounded.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/unbounded.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- manager + scheduler end-to-end ----- + + +def test_manager_enqueue_to_completion(model_root): + payload = _safetensors_payload(2 * 1024 * 1024) + + async def _run(): + await close_session() + from app.model_downloader.manager import DOWNLOAD_MANAGER + + runner, port = await _serve(_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors") + # Wait for completion. + final_path, _ = paths.resolve_destination("loras/e2e.safetensors") + for _ in range(500): + await asyncio.sleep(0.02) + row = queries.get_download(did) + if row.status in DownloadStatus.TERMINAL: + break + row = queries.get_download(did) + assert row.status == DownloadStatus.COMPLETED, row.error + assert open(final_path, "rb").read() == payload + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_manager_rejects_disallowed_url(model_root): + async def _run(): + from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError + + with pytest.raises(DownloadError) as ei: + await DOWNLOAD_MANAGER.enqueue( + "https://evil.example.com/x.safetensors", "loras/bad.safetensors" + ) + assert ei.value.code == "URL_NOT_ALLOWED" + + asyncio.run(_run()) + + +def test_manager_resolves_extensionless_url(model_root): + """An allowlisted URL with no extension in its path is resolved from the + response, and the stored file adopts the resolved extension.""" + payload = _safetensors_payload(1 * 1024 * 1024) + + async def _run(): + await close_session() + from app.model_downloader.manager import DOWNLOAD_MANAGER + + runner, port = await _serve( + _content_disposition_handler(payload, "RealModel.safetensors") + ) + try: + # No extension in the path (Civitai-style) and none in the model_id. + url = f"http://127.0.0.1:{port}/api/download/models/12345" + did = await DOWNLOAD_MANAGER.enqueue(url, "loras/my_civitai_model") + + row = queries.get_download(did) + # The resolved extension was appended to the model_id + destination. + assert row.model_id == "loras/my_civitai_model.safetensors" + assert row.dest_path.endswith("my_civitai_model.safetensors") + + final_path, _ = paths.resolve_destination( + "loras/my_civitai_model.safetensors" + ) + for _ in range(500): + await asyncio.sleep(0.02) + row = queries.get_download(did) + if row.status in DownloadStatus.TERMINAL: + break + row = queries.get_download(did) + assert row.status == DownloadStatus.COMPLETED, row.error + assert open(final_path, "rb").read() == payload + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_manager_overrides_extension_from_resolution(model_root): + """A model_id carrying a different known extension is corrected to match + the resolved URL's extension.""" + payload = _safetensors_payload(256 * 1024) + + async def _run(): + await close_session() + from app.model_downloader.manager import DOWNLOAD_MANAGER + + runner, port = await _serve( + _content_disposition_handler(payload, "weights.safetensors") + ) + try: + url = f"http://127.0.0.1:{port}/api/download/models/777" + # Caller guessed .ckpt; resolution says .safetensors -> corrected. + did = await DOWNLOAD_MANAGER.enqueue(url, "loras/guessed.ckpt") + row = queries.get_download(did) + assert row.model_id == "loras/guessed.safetensors" + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_manager_rejects_non_model_resolution(model_root): + """A URL that resolves to a non-model file is rejected, not downloaded.""" + + async def _run(): + await close_session() + from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError + + runner, port = await _serve( + _content_disposition_handler(b"not a model", "installer.zip") + ) + try: + url = f"http://127.0.0.1:{port}/api/download/models/999" + with pytest.raises(DownloadError) as ei: + await DOWNLOAD_MANAGER.enqueue(url, "loras/whatever") + assert ei.value.code == "URL_NOT_ALLOWED" + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) diff --git a/tests-unit/model_downloader_test/test_planner_structural.py b/tests-unit/model_downloader_test/test_planner_structural.py new file mode 100644 index 000000000..139cf4094 --- /dev/null +++ b/tests-unit/model_downloader_test/test_planner_structural.py @@ -0,0 +1,81 @@ +"""Unit tests for the segment planner and structural safetensors validation.""" + +from __future__ import annotations + +import json +import struct + +import pytest + +from app.model_downloader.engine.planner import ( + effective_segment_count, + plan_segments, +) +from app.model_downloader.verify import structural + + +# ----- planner ----- + + +def test_plan_segments_covers_full_range_contiguously(): + total = 1000 + plans = plan_segments(total, 4) + assert len(plans) == 4 + assert plans[0].start == 0 + assert plans[-1].end == total - 1 + # contiguous, no gaps/overlaps + for a, b in zip(plans, plans[1:]): + assert b.start == a.end + 1 + assert sum(p.length for p in plans) == total + + +def test_effective_segment_count_falls_back_to_single(): + # No range support -> single + assert effective_segment_count(10_000_000, False, 8) == 1 + # Unknown size -> single + assert effective_segment_count(None, True, 8) == 1 + # Tiny file -> fewer segments than configured + assert effective_segment_count(1024, True, 8) == 1 + # Large file with range support -> configured count + assert effective_segment_count(1_000_000_000, True, 8) == 8 + + +# ----- structural ----- + + +def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes: + header = {"t": {"dtype": "F32", "shape": [tensor_data_len], "data_offsets": [0, tensor_data_len]}} + header_bytes = json.dumps(header).encode("utf-8") + body = b"\x00" * tensor_data_len + if corrupt_size: + body = body[:-1] # truncate one byte + return struct.pack(" allowed + ("https://civitai.com/x/model.safetensors", True), + # no extension in the path (Civitai download API) -> allowed, resolved later + ("https://civitai.com/api/download/models/3031464?fileId=2910346", True), + ("https://civitai.com/api/download/models/3031464", True), + # explicit non-model extension -> rejected even on an allowed host + ("https://civitai.com/api/download/models/thing.zip", False), + ("https://huggingface.co/org/repo/resolve/main/config.json", False), + # off-list host is never downloadable + ("https://evil.example.com/api/download/models/1", False), + # http to a non-loopback allowlisted host is not permitted + ("http://civitai.com/api/download/models/1", False), + ], +) +def test_is_url_downloadable(url, downloadable): + assert allowlist.is_url_downloadable(url) is downloadable + + +@pytest.mark.parametrize( + "name,ext", + [ + ("model.safetensors", ".safetensors"), + ("model.SAFETENSORS", ".safetensors"), + ("archive.tar.gz", ".gz"), + ("noext", ""), + (".safetensors", ""), # leading-dot dotfile -> no extension + ("a/b/c/model.ckpt", ".ckpt"), + ], +) +def test_filename_extension(name, ext): + assert allowlist.filename_extension(name) == ext + + +# ----- SSRF: blocked IPs ----- + + +@pytest.mark.parametrize( + "ip,blocked", + [ + ("169.254.169.254", True), # cloud metadata / link-local + ("127.0.0.1", True), + ("10.0.0.5", True), + ("192.168.1.1", True), + ("172.16.0.1", True), + ("::1", True), + ("0.0.0.0", True), + # IPv4-mapped IPv6: must see through the mapping even on CPython + # versions predating the gh-113171 is_* property fix. + ("::ffff:169.254.169.254", True), # mapped cloud metadata + ("::ffff:127.0.0.1", True), # mapped loopback + ("::ffff:10.0.0.1", True), # mapped RFC1918 + ("::ffff:8.8.8.8", False), # mapped public address stays allowed + ("8.8.8.8", False), + ("1.1.1.1", False), + ("not-an-ip", True), # unparseable -> refuse + ], +) +def test_is_blocked_ip(ip, blocked): + assert is_blocked_ip(ip) is blocked + + +# ----- SSRF: redirect hop validation ----- + + +def test_check_redirect_hop_rejects_bad_scheme_and_userinfo(): + with pytest.raises(SSRFError): + check_redirect_hop("ftp://huggingface.co/x.safetensors") + with pytest.raises(SSRFError): + check_redirect_hop("https://user:pass@cdn.example.com/x") + # A CDN host that is NOT on the allowlist is allowed as a redirect target + # (private-IP protection is the resolver's job; credential leak is prevented + # by exact host matching). + assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None + + +def test_check_redirect_hop_http_only_for_loopback(): + # Plain http to an external host is rejected (no plaintext downgrade). + with pytest.raises(SSRFError): + check_redirect_hop("http://cdn-lfs.huggingface.co/abc") + # http is honored for loopback only on the initial user-supplied URL (the + # "download a local model" feature). + assert ( + check_redirect_hop("http://localhost/x.safetensors", is_initial_url=True) + is not None + ) + assert ( + check_redirect_hop("http://127.0.0.1/x.safetensors", is_initial_url=True) + is not None + ) + + +def test_check_redirect_hop_blocks_loopback_and_ip_literals_on_redirect(): + # A redirect (is_initial_url=False, the default) must never reach loopback, + # whether by hostname or by IP literal, nor any other internal IP literal. + for target in ( + "http://localhost/x.safetensors", + "http://127.0.0.1/x.safetensors", + "https://[::1]/x.safetensors", + "https://169.254.169.254/x.safetensors", # cloud metadata + "https://10.0.0.5/x.safetensors", # RFC1918 + ): + with pytest.raises(SSRFError): + check_redirect_hop(target) + # Off-allowlist public CDN hosts (hostnames) remain valid redirect targets; + # their resolved IPs are screened by the connector's resolver. + assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None + + +# ----- path safety ----- + + +def test_parse_model_id_valid(model_root): + directory, filename = paths.parse_model_id("loras/my_lora.safetensors") + assert directory == "loras" + assert filename == "my_lora.safetensors" + + +@pytest.mark.parametrize( + "model_id", + [ + "loras/../etc/passwd.safetensors", # traversal + "loras/sub/dir.safetensors", # nested + "unknownfolder/x.safetensors", # unknown folder + "loras/model.txt", # bad extension + "noslash.safetensors", # missing directory + "loras/", # empty filename + ], +) +def test_parse_model_id_rejects(model_root, model_id): + with pytest.raises(paths.InvalidModelId): + paths.parse_model_id(model_id) + + +def test_resolve_destination_stays_in_root(model_root): + final_path, temp_path = paths.resolve_destination("loras/x.safetensors") + assert final_path.startswith(model_root) + assert temp_path.startswith(model_root) + assert temp_path != final_path + + +@pytest.mark.parametrize( + "model_id,ext,expected", + [ + # no extension -> append the resolved one + ("loras/my_civitai_model", ".safetensors", "loras/my_civitai_model.safetensors"), + # different known extension -> replace it + ("loras/mymodel.ckpt", ".safetensors", "loras/mymodel.safetensors"), + # same extension -> unchanged + ("loras/mymodel.safetensors", ".safetensors", "loras/mymodel.safetensors"), + # non-model suffix is treated as a stem, extension appended + ("loras/my.model.v2", ".safetensors", "loras/my.model.v2.safetensors"), + # malformed (no slash) is returned untouched for parse_model_id to reject + ("noslash", ".safetensors", "noslash"), + ], +) +def test_apply_extension(model_id, ext, expected): + assert paths.apply_extension(model_id, ext) == expected + + +# ----- Content-Disposition filename parsing ----- + + +@pytest.mark.parametrize( + "header,expected", + [ + ('attachment; filename="model.safetensors"', "model.safetensors"), + ("attachment; filename=model.ckpt", "model.ckpt"), + # RFC 5987 form is preferred and percent-decoded + ( + "attachment; filename=\"fallback.bin\"; filename*=UTF-8''my%20model.safetensors", + "my model.safetensors", + ), + # directory components in a hostile header are stripped to the basename + ('attachment; filename="../../etc/passwd"', "passwd"), + ('attachment; filename="a\\\\b\\\\model.pt"', "model.pt"), + ("inline", None), + (None, None), + ], +) +def test_filename_from_content_disposition(header, expected): + from app.model_downloader.net.http import filename_from_content_disposition + + assert filename_from_content_disposition(header) == expected