From c7c18377a32e345d0b3e520b82eb811ec9dea113 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Sat, 27 Jun 2026 13:10:05 +0200 Subject: [PATCH] Add initial commit for model downloader. --- alembic_db/versions/0005_download_manager.py | 118 +++++ app/database/db.py | 1 + app/model_downloader/api/routes.py | 203 ++++++++ app/model_downloader/api/schemas_in.py | 51 ++ app/model_downloader/api/schemas_out.py | 26 + app/model_downloader/constants.py | 38 ++ app/model_downloader/credentials/resolver.py | 99 ++++ app/model_downloader/credentials/store.py | 137 +++++ app/model_downloader/database/models.py | 162 ++++++ app/model_downloader/database/queries.py | 235 +++++++++ app/model_downloader/engine/job.py | 443 ++++++++++++++++ app/model_downloader/engine/planner.py | 51 ++ app/model_downloader/engine/writer.py | 61 +++ app/model_downloader/manager.py | 294 +++++++++++ app/model_downloader/net/http.py | 110 ++++ app/model_downloader/net/probe.py | 90 ++++ app/model_downloader/net/session.py | 72 +++ app/model_downloader/scheduler.py | 160 ++++++ app/model_downloader/security/allowlist.py | 84 +++ app/model_downloader/security/paths.py | 110 ++++ app/model_downloader/security/ssrf.py | 111 ++++ app/model_downloader/verify/checksum.py | 49 ++ app/model_downloader/verify/dedup.py | 53 ++ app/model_downloader/verify/structural.py | 65 +++ comfy/cli_args.py | 8 + comfy_api/feature_flags.py | 1 + openapi.yaml | 483 ++++++++++++++++++ server.py | 21 + tests-unit/model_downloader_test/conftest.py | 63 +++ .../model_downloader_test/test_credentials.py | 110 ++++ .../test_engine_integration.py | 270 ++++++++++ .../test_planner_structural.py | 71 +++ .../model_downloader_test/test_security.py | 111 ++++ 33 files changed, 3961 insertions(+) create mode 100644 alembic_db/versions/0005_download_manager.py create mode 100644 app/model_downloader/api/routes.py create mode 100644 app/model_downloader/api/schemas_in.py create mode 100644 app/model_downloader/api/schemas_out.py create mode 100644 app/model_downloader/constants.py create mode 100644 app/model_downloader/credentials/resolver.py create mode 100644 app/model_downloader/credentials/store.py create mode 100644 app/model_downloader/database/models.py create mode 100644 app/model_downloader/database/queries.py create mode 100644 app/model_downloader/engine/job.py create mode 100644 app/model_downloader/engine/planner.py create mode 100644 app/model_downloader/engine/writer.py create mode 100644 app/model_downloader/manager.py create mode 100644 app/model_downloader/net/http.py create mode 100644 app/model_downloader/net/probe.py create mode 100644 app/model_downloader/net/session.py create mode 100644 app/model_downloader/scheduler.py create mode 100644 app/model_downloader/security/allowlist.py create mode 100644 app/model_downloader/security/paths.py create mode 100644 app/model_downloader/security/ssrf.py create mode 100644 app/model_downloader/verify/checksum.py create mode 100644 app/model_downloader/verify/dedup.py create mode 100644 app/model_downloader/verify/structural.py create mode 100644 tests-unit/model_downloader_test/conftest.py create mode 100644 tests-unit/model_downloader_test/test_credentials.py create mode 100644 tests-unit/model_downloader_test/test_engine_integration.py create mode 100644 tests-unit/model_downloader_test/test_planner_structural.py create mode 100644 tests-unit/model_downloader_test/test_security.py diff --git a/alembic_db/versions/0005_download_manager.py b/alembic_db/versions/0005_download_manager.py new file mode 100644 index 000000000..e2f879e75 --- /dev/null +++ b/alembic_db/versions/0005_download_manager.py @@ -0,0 +1,118 @@ +""" +Download manager schema. + +Adds the three tables that back the server-side model download manager +(PRD section 7): transient job/queue state (``downloads`` + per-segment +``download_segments``) and one-API-key-per-host auth (``host_credentials``). + +The local file catalog / dedup index is intentionally NOT added here — it +is owned by the assets system (``assets`` / ``asset_references``). + +Revision ID: 0005_download_manager +Revises: 0004_drop_tag_type +Create Date: 2026-06-27 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0005_download_manager" +down_revision = "0004_drop_tag_type" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "downloads", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("url", sa.Text(), nullable=False), + sa.Column("final_url", sa.Text(), nullable=True), + sa.Column("model_id", sa.String(length=1024), nullable=False), + sa.Column("dest_path", sa.Text(), nullable=False), + sa.Column("temp_path", sa.Text(), nullable=False), + sa.Column("status", sa.String(length=16), nullable=False), + sa.Column("priority", sa.Integer(), nullable=False, server_default="0"), + sa.Column("total_bytes", sa.BigInteger(), nullable=True), + sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("etag", sa.String(length=512), nullable=True), + sa.Column("last_modified", sa.String(length=128), nullable=True), + sa.Column( + "accept_ranges", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column("expected_sha256", sa.String(length=64), nullable=True), + sa.Column("credential_id", sa.String(length=36), nullable=True), + sa.Column( + "allow_any_extension", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"), + sa.Column("error", sa.Text(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"), + sa.CheckConstraint( + "total_bytes IS NULL OR total_bytes >= 0", + name="ck_downloads_total_bytes_nonneg", + ), + ) + op.create_index("ix_downloads_status", "downloads", ["status"]) + op.create_index("ix_downloads_priority", "downloads", ["priority"]) + op.create_index("ix_downloads_model_id", "downloads", ["model_id"]) + + op.create_table( + "download_segments", + sa.Column( + "download_id", + sa.String(length=36), + sa.ForeignKey("downloads.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("idx", sa.Integer(), nullable=False), + sa.Column("start_offset", sa.BigInteger(), nullable=False), + sa.Column("end_offset", sa.BigInteger(), nullable=False), + sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"), + sa.PrimaryKeyConstraint("download_id", "idx", name="pk_download_segments"), + sa.CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"), + sa.CheckConstraint("end_offset >= start_offset", name="ck_segments_range"), + ) + + op.create_table( + "host_credentials", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("host", sa.String(length=255), nullable=False), + sa.Column( + "match_subdomains", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column("label", sa.String(length=255), nullable=True), + sa.Column( + "auth_scheme", sa.String(length=16), nullable=False, server_default="bearer" + ), + sa.Column("header_name", sa.String(length=255), nullable=True), + sa.Column("query_param", sa.String(length=255), nullable=True), + sa.Column("secret", sa.Text(), nullable=False), + sa.Column("secret_last4", sa.String(length=4), nullable=True), + sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.text("true")), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + ) + op.create_index( + "uq_host_credentials_host", "host_credentials", ["host"], unique=True + ) + + +def downgrade() -> None: + op.drop_index("uq_host_credentials_host", table_name="host_credentials") + op.drop_table("host_credentials") + + op.drop_table("download_segments") + + op.drop_index("ix_downloads_model_id", table_name="downloads") + op.drop_index("ix_downloads_priority", table_name="downloads") + op.drop_index("ix_downloads_status", table_name="downloads") + op.drop_table("downloads") diff --git a/app/database/db.py b/app/database/db.py index 0aab09a49..8d776344e 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -21,6 +21,7 @@ try: from app.database.models import Base import app.assets.database.models # noqa: F401 — register models with Base.metadata + import app.model_downloader.database.models # noqa: F401 — register models with Base.metadata _DB_AVAILABLE = True except ImportError as e: diff --git a/app/model_downloader/api/routes.py b/app/model_downloader/api/routes.py new file mode 100644 index 000000000..2c0ed5ae4 --- /dev/null +++ b/app/model_downloader/api/routes.py @@ -0,0 +1,203 @@ +"""aiohttp routes for the download manager. + +Endpoint surface (all under ``/api/download``), mirroring the response +envelope used by ``app/assets/api/routes.py``: + + POST /api/download/enqueue + GET /api/download + POST /api/download/availability + POST /api/download/credentials + GET /api/download/credentials + GET /api/download/credentials/{id} + DELETE /api/download/credentials/{id} + GET /api/download/{id} + POST /api/download/{id}/pause + POST /api/download/{id}/resume + POST /api/download/{id}/cancel + POST /api/download/{id}/priority + +Note on ordering: the static ``credentials`` routes are registered before the +dynamic ``/api/download/{id}`` route so a request to ``.../credentials`` is not +captured as ``id == "credentials"``. +""" + +from __future__ import annotations + +import json + +from aiohttp import web +from pydantic import BaseModel, ValidationError + +from app.model_downloader.api import schemas_in, schemas_out +from app.model_downloader.credentials.store import ( + CREDENTIAL_STORE, + CredentialValidationError, +) +from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError + +ROUTES = web.RouteTableDef() + + +def register_routes(app: web.Application) -> None: + """Wire the download-manager routes into the running aiohttp app.""" + app.add_routes(ROUTES) + + +# ----- envelope helpers (same shape as app/assets/api/routes.py) ----- + + +def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _ok(payload, status: int = 200) -> web.Response: + return web.json_response(payload, status=status) + + +async def _parse(request: web.Request, model: type[BaseModel]): + try: + raw = await request.json() + except json.JSONDecodeError: + return _error(400, "INVALID_JSON", "Request body must be valid JSON.") + try: + return model.model_validate(raw) + except ValidationError as ve: + return _error(400, "INVALID_BODY", "Validation failed.", {"errors": json.loads(ve.json())}) + + +def _from_download_error(e: DownloadError) -> web.Response: + return _error(e.http_status, e.code, e.message) + + +# ----- downloads: collection + enqueue + availability ----- + + +@ROUTES.post("/api/download/enqueue") +async def enqueue(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.EnqueueRequest) + if isinstance(parsed, web.Response): + return parsed + try: + download_id = await DOWNLOAD_MANAGER.enqueue( + parsed.url, + parsed.model_id, + priority=parsed.priority, + expected_sha256=parsed.expected_sha256, + allow_any_extension=parsed.allow_any_extension, + credential_id=parsed.credential_id, + ) + except DownloadError as e: + return _from_download_error(e) + return _ok({"download_id": download_id, "accepted": True}, status=202) + + +@ROUTES.get("/api/download") +async def list_downloads(request: web.Request) -> web.Response: + return _ok({"downloads": await DOWNLOAD_MANAGER.list()}) + + +@ROUTES.post("/api/download/availability") +async def availability(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.AvailabilityRequest) + if isinstance(parsed, web.Response): + return parsed + return _ok({"models": await DOWNLOAD_MANAGER.availability(parsed.models)}) + + +# ----- credentials (secrets are write-only) — must precede /{id} ----- + + +@ROUTES.post("/api/download/credentials") +async def upsert_credential(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.CredentialUpsertRequest) + if isinstance(parsed, web.Response): + return parsed + try: + view = await CREDENTIAL_STORE.upsert( + parsed.host, + parsed.secret, + auth_scheme=parsed.auth_scheme, + header_name=parsed.header_name, + query_param=parsed.query_param, + label=parsed.label, + match_subdomains=parsed.match_subdomains, + enabled=parsed.enabled, + ) + except CredentialValidationError as e: + return _error(400, "INVALID_CREDENTIAL", str(e)) + return _ok(schemas_out.credential_to_dict(view), status=201) + + +@ROUTES.get("/api/download/credentials") +async def list_credentials(request: web.Request) -> web.Response: + views = await CREDENTIAL_STORE.list() + return _ok({"credentials": [schemas_out.credential_to_dict(v) for v in views]}) + + +@ROUTES.get("/api/download/credentials/{id}") +async def get_credential(request: web.Request) -> web.Response: + view = await CREDENTIAL_STORE.get(request.match_info["id"]) + if view is None: + return _error(404, "NOT_FOUND", "No such credential.") + return _ok(schemas_out.credential_to_dict(view)) + + +@ROUTES.delete("/api/download/credentials/{id}") +async def delete_credential(request: web.Request) -> web.Response: + deleted = await CREDENTIAL_STORE.delete(request.match_info["id"]) + if not deleted: + return _error(404, "NOT_FOUND", "No such credential.") + return _ok({"deleted": True}) + + +# ----- single download by id (dynamic; registered last) ----- + + +@ROUTES.get("/api/download/{id}") +async def get_download(request: web.Request) -> web.Response: + view = await DOWNLOAD_MANAGER.status(request.match_info["id"]) + if view is None: + return _error(404, "NOT_FOUND", "No such download.") + return _ok(view) + + +@ROUTES.post("/api/download/{id}/pause") +async def pause(request: web.Request) -> web.Response: + try: + await DOWNLOAD_MANAGER.pause(request.match_info["id"]) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) + + +@ROUTES.post("/api/download/{id}/resume") +async def resume(request: web.Request) -> web.Response: + try: + await DOWNLOAD_MANAGER.resume(request.match_info["id"]) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) + + +@ROUTES.post("/api/download/{id}/cancel") +async def cancel(request: web.Request) -> web.Response: + try: + await DOWNLOAD_MANAGER.cancel(request.match_info["id"]) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) + + +@ROUTES.post("/api/download/{id}/priority") +async def set_priority(request: web.Request) -> web.Response: + parsed = await _parse(request, schemas_in.PriorityRequest) + if isinstance(parsed, web.Response): + return parsed + try: + await DOWNLOAD_MANAGER.set_priority(request.match_info["id"], parsed.priority) + except DownloadError as e: + return _from_download_error(e) + return _ok({"ok": True}) diff --git a/app/model_downloader/api/schemas_in.py b/app/model_downloader/api/schemas_in.py new file mode 100644 index 000000000..c2db2feb4 --- /dev/null +++ b/app/model_downloader/api/schemas_in.py @@ -0,0 +1,51 @@ +"""Request schemas for the download manager API. + +Pydantic enforces shape at the boundary; handlers operate only on validated +values past that point. +""" + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, Field + +from app.model_downloader.constants import AUTH_SCHEME_BEARER + + +class EnqueueRequest(BaseModel): + url: str + model_id: str + priority: int = 0 + expected_sha256: Optional[str] = None + allow_any_extension: bool = False + credential_id: Optional[str] = None + + +class PriorityRequest(BaseModel): + priority: int + + +class AvailabilityRequest(BaseModel): + """``{model_id: url}`` — the URLs declared in the workflow JSON.""" + + models: dict[str, str] = Field(default_factory=dict) + + +class CredentialUpsertRequest(BaseModel): + host: str + secret: str + auth_scheme: str = AUTH_SCHEME_BEARER + header_name: Optional[str] = None + query_param: Optional[str] = None + label: Optional[str] = None + match_subdomains: bool = False + enabled: bool = True + + +__all__ = [ + "EnqueueRequest", + "PriorityRequest", + "AvailabilityRequest", + "CredentialUpsertRequest", +] diff --git a/app/model_downloader/api/schemas_out.py b/app/model_downloader/api/schemas_out.py new file mode 100644 index 000000000..59ace6430 --- /dev/null +++ b/app/model_downloader/api/schemas_out.py @@ -0,0 +1,26 @@ +"""Response helpers for the download manager API. + +The download/status read models are plain dicts produced by the manager. This +module only needs to mask credentials for output (the secret is never returned). +""" + +from __future__ import annotations + +from app.model_downloader.credentials.store import CredentialView + + +def credential_to_dict(view: CredentialView) -> dict: + """API-safe credential representation — never includes the secret.""" + return { + "id": view.id, + "host": view.host, + "auth_scheme": view.auth_scheme, + "header_name": view.header_name, + "query_param": view.query_param, + "label": view.label, + "match_subdomains": view.match_subdomains, + "enabled": view.enabled, + "secret_last4": view.secret_last4, + "created_at": view.created_at, + "updated_at": view.updated_at, + } diff --git a/app/model_downloader/constants.py b/app/model_downloader/constants.py new file mode 100644 index 000000000..0692a7430 --- /dev/null +++ b/app/model_downloader/constants.py @@ -0,0 +1,38 @@ +"""Shared constants for the download manager. + +Status values are persisted as TEXT in the ``downloads`` table; keep them +stable. The lifecycle is (PRD section 6): + + queued -> active -> verifying -> completed + | |-> paused -> (resume) -> active + | |-> failed (network, retryable) -> queued (backoff) + |-> cancelled +""" + +from __future__ import annotations + +# Auth schemes for HostCredential (PRD section 9.4.1). +AUTH_SCHEME_BEARER = "bearer" +AUTH_SCHEME_HEADER = "header" +AUTH_SCHEME_QUERY = "query" +AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY) + + +class DownloadStatus: + QUEUED = "queued" + ACTIVE = "active" + PAUSED = "paused" + VERIFYING = "verifying" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + #: States from which a worker is doing (or about to do) network I/O. + LIVE = (QUEUED, ACTIVE, VERIFYING) + #: Terminal states — the job will not transition again on its own. + TERMINAL = (COMPLETED, FAILED, CANCELLED) + + +# Default temp-file suffix. Distinctive so the startup orphan sweep only +# removes files THIS subsystem created, never unrelated *.tmp files. +TMP_SUFFIX = ".comfy-download.part" diff --git a/app/model_downloader/credentials/resolver.py b/app/model_downloader/credentials/resolver.py new file mode 100644 index 000000000..f6fa50c92 --- /dev/null +++ b/app/model_downloader/credentials/resolver.py @@ -0,0 +1,99 @@ +"""Turn a stored credential into a per-hop request modifier (PRD section 9.4.2). + +The critical rule: a credential is only ever attached when *the current hop's +host* matches a stored credential, and only over https. This is recomputed +from scratch on every redirect hop, so a token bound to ``huggingface.co`` is +silently dropped when the request is redirected to a presigned CDN host — +which is exactly what these hubs expect. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import Optional +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit + +from app.model_downloader.constants import ( + AUTH_SCHEME_BEARER, + AUTH_SCHEME_HEADER, + AUTH_SCHEME_QUERY, +) +from app.model_downloader.credentials.store import normalize_host +from app.model_downloader.database import queries +from app.model_downloader.database.models import HostCredential + + +@dataclass +class RequestAuth: + """How to modify a single request to carry a credential.""" + + headers: dict[str, str] = field(default_factory=dict) + query: dict[str, str] = field(default_factory=dict) + + def apply_to_url(self, url: str) -> str: + if not self.query: + return url + parts = urlsplit(url) + params = dict(parse_qsl(parts.query, keep_blank_values=True)) + params.update(self.query) + return urlunsplit(parts._replace(query=urlencode(params))) + + +def _matches(cred: HostCredential, hop_host: str) -> bool: + cred_host = cred.host + if hop_host == cred_host: + return True + if cred.match_subdomains: + # Label-boundary suffix: api.example.com matches example.com, but + # evil-example.com does NOT. + return hop_host.endswith("." + cred_host) + return False + + +def _build_auth(cred: HostCredential) -> RequestAuth: + if cred.auth_scheme == AUTH_SCHEME_BEARER: + return RequestAuth(headers={"Authorization": f"Bearer {cred.secret}"}) + if cred.auth_scheme == AUTH_SCHEME_HEADER: + name = cred.header_name or "Authorization" + return RequestAuth(headers={name: cred.secret}) + if cred.auth_scheme == AUTH_SCHEME_QUERY and cred.query_param: + return RequestAuth(query={cred.query_param: cred.secret}) + return RequestAuth() + + +def _resolve_sync( + host: str, scheme: str, explicit_credential_id: Optional[str] +) -> Optional[RequestAuth]: + # Never attach a secret over a non-https hop (PRD section 9.4.2). + if scheme.lower() != "https": + return None + hop_host = normalize_host(host) + if not hop_host: + return None + + if explicit_credential_id is not None: + cred = queries.get_credential(explicit_credential_id) + # An explicit credential is still subject to the per-hop host check — + # it is not forced onto a non-matching host. + if cred is None or not cred.enabled or not _matches(cred, hop_host): + return None + return _build_auth(cred) + + # Auto-resolve: exact host first, then any subdomain-matching credential. + cred = queries.get_credential_by_host(hop_host) + if cred is not None and cred.enabled: + return _build_auth(cred) + for sub in queries.list_subdomain_credentials(): + if sub.enabled and _matches(sub, hop_host): + return _build_auth(sub) + return None + + +async def resolve_auth_for_hop( + host: str, scheme: str, *, explicit_credential_id: Optional[str] = None +) -> Optional[RequestAuth]: + """Resolve the credential (if any) to attach for one request hop.""" + return await asyncio.to_thread( + _resolve_sync, host, scheme, explicit_credential_id + ) diff --git a/app/model_downloader/credentials/store.py b/app/model_downloader/credentials/store.py new file mode 100644 index 000000000..d1c979f05 --- /dev/null +++ b/app/model_downloader/credentials/store.py @@ -0,0 +1,137 @@ +"""The credential store: one API key per host (PRD section 9.4). + +Secrets are write-only over the API — :class:`CredentialView` carries only +masked metadata (``secret_last4`` + scheme + label), never the secret itself. +At-rest protection for v1 is filesystem permissions on the shared DB (the DB +is the trust boundary); encryption-at-rest is a noted future seam. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Optional + +from app.model_downloader.constants import ( + AUTH_SCHEME_BEARER, + AUTH_SCHEME_HEADER, + AUTH_SCHEME_QUERY, + AUTH_SCHEMES, +) +from app.model_downloader.database import queries +from app.model_downloader.database.models import HostCredential + + +def normalize_host(host: str) -> str: + """Lowercase, strip port, IDNA-encode (PRD section 9.4.3).""" + if not host: + return "" + host = host.strip().lower() + if host.startswith("[") and "]" in host: # bracketed IPv6 literal + host = host[1 : host.index("]")] + elif host.count(":") == 1: # host:port (not IPv6) + host = host.split(":", 1)[0] + try: + host = host.encode("idna").decode("ascii") + except (UnicodeError, ValueError): + pass + return host + + +@dataclass(frozen=True) +class CredentialView: + """Masked, API-safe view of a credential — never includes the secret.""" + + id: str + host: str + auth_scheme: str + header_name: Optional[str] + query_param: Optional[str] + label: Optional[str] + match_subdomains: bool + enabled: bool + secret_last4: Optional[str] + created_at: int + updated_at: int + + +def _to_view(row: HostCredential) -> CredentialView: + return CredentialView( + id=row.id, + host=row.host, + auth_scheme=row.auth_scheme, + header_name=row.header_name, + query_param=row.query_param, + label=row.label, + match_subdomains=row.match_subdomains, + enabled=row.enabled, + secret_last4=row.secret_last4, + created_at=row.created_at, + updated_at=row.updated_at, + ) + + +class CredentialValidationError(ValueError): + """A credential upsert had inconsistent fields.""" + + +class CredentialStore: + """Async facade over the ``host_credentials`` table. + + DB access is synchronous (SQLite) and offloaded via ``asyncio.to_thread``. + """ + + async def upsert( + self, + host: str, + secret: str, + *, + auth_scheme: str = AUTH_SCHEME_BEARER, + header_name: Optional[str] = None, + query_param: Optional[str] = None, + label: Optional[str] = None, + match_subdomains: bool = False, + enabled: bool = True, + ) -> CredentialView: + host = normalize_host(host) + if not host: + raise CredentialValidationError("host is required") + if not secret: + raise CredentialValidationError("secret is required") + if auth_scheme not in AUTH_SCHEMES: + raise CredentialValidationError( + f"auth_scheme must be one of {AUTH_SCHEMES}, got {auth_scheme!r}" + ) + if auth_scheme == AUTH_SCHEME_HEADER and not header_name: + header_name = "Authorization" + if auth_scheme == AUTH_SCHEME_QUERY and not query_param: + raise CredentialValidationError( + "query_param is required when auth_scheme='query'" + ) + values = { + "host": host, + "secret": secret, + "secret_last4": secret[-4:] if len(secret) >= 4 else secret, + "auth_scheme": auth_scheme, + "header_name": header_name, + "query_param": query_param, + "label": label, + "match_subdomains": match_subdomains, + "enabled": enabled, + } + row = await asyncio.to_thread(queries.upsert_credential, values) + return _to_view(row) + + async def list(self) -> list[CredentialView]: + rows = await asyncio.to_thread(queries.list_credentials) + return [_to_view(r) for r in rows] + + async def get(self, credential_id: str) -> Optional[CredentialView]: + row = await asyncio.to_thread(queries.get_credential, credential_id) + return _to_view(row) if row is not None else None + + async def delete(self, credential_id: str) -> bool: + return await asyncio.to_thread(queries.delete_credential, credential_id) + + +CREDENTIAL_STORE = CredentialStore() diff --git a/app/model_downloader/database/models.py b/app/model_downloader/database/models.py new file mode 100644 index 000000000..cc2afda58 --- /dev/null +++ b/app/model_downloader/database/models.py @@ -0,0 +1,162 @@ +"""SQLAlchemy models for the download manager. + +Three tables (PRD section 7): + +- ``downloads`` one row per requested file (job + queue state). +- ``download_segments`` per-segment byte progress, for segmented resume. +- ``host_credentials`` one API key per host, reused across downloads. + +The local file catalog / dedup index is NOT here — that is owned by the +assets system (``assets`` / ``asset_references``). On completion a finished +file is registered into the assets catalog; ``downloads`` is kept only as +job history. +""" + +from __future__ import annotations + +import time +import uuid + +from sqlalchemy import ( + BigInteger, + Boolean, + CheckConstraint, + ForeignKey, + Index, + Integer, + String, + Text, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database.models import Base + + +def _uuid() -> str: + return str(uuid.uuid4()) + + +def _now() -> int: + return int(time.time()) + + +class Download(Base): + __tablename__ = "downloads" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) + # Original requested URL and the final URL after validated redirects. + url: Mapped[str] = mapped_column(Text, nullable=False) + final_url: Mapped[str | None] = mapped_column(Text, nullable=True) + # Canonical "/" identifier (resolved via folder_paths). + model_id: Mapped[str] = mapped_column(String(1024), nullable=False) + # Final on-disk location and the .part write target. + dest_path: Mapped[str] = mapped_column(Text, nullable=False) + temp_path: Mapped[str] = mapped_column(Text, nullable=False) + + status: Mapped[str] = mapped_column(String(16), nullable=False) + priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + total_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + + etag: Mapped[str | None] = mapped_column(String(512), nullable=True) + last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True) + accept_ranges: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + # Optional hub-provided checksum to verify against (NOT the dedup key). + expected_sha256: Mapped[str | None] = mapped_column(String(64), nullable=True) + + # Explicit credential override; otherwise auto-resolved by host. + credential_id: Mapped[str | None] = mapped_column(String(36), nullable=True) + allow_any_extension: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) + # How many retryable failures we have seen (for backoff capping). + attempts: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + error: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now) + updated_at: Mapped[int] = mapped_column( + BigInteger, nullable=False, default=_now, onupdate=_now + ) + + segments: Mapped[list[DownloadSegment]] = relationship( + "DownloadSegment", + back_populates="download", + cascade="all,delete-orphan", + passive_deletes=True, + order_by="DownloadSegment.idx", + ) + + __table_args__ = ( + Index("ix_downloads_status", "status"), + Index("ix_downloads_priority", "priority"), + Index("ix_downloads_model_id", "model_id"), + CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"), + CheckConstraint( + "total_bytes IS NULL OR total_bytes >= 0", + name="ck_downloads_total_bytes_nonneg", + ), + ) + + def __repr__(self) -> str: + return f"" + + +class DownloadSegment(Base): + __tablename__ = "download_segments" + + download_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("downloads.id", ondelete="CASCADE"), + primary_key=True, + ) + idx: Mapped[int] = mapped_column(Integer, primary_key=True) + start_offset: Mapped[int] = mapped_column(BigInteger, nullable=False) + end_offset: Mapped[int] = mapped_column(BigInteger, nullable=False) + bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + + download: Mapped[Download] = relationship("Download", back_populates="segments") + + __table_args__ = ( + CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"), + CheckConstraint("end_offset >= start_offset", name="ck_segments_range"), + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class HostCredential(Base): + __tablename__ = "host_credentials" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) + # Normalized lowercase hostname, e.g. "civitai.com". + host: Mapped[str] = mapped_column(String(255), nullable=False) + match_subdomains: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) + label: Mapped[str | None] = mapped_column(String(255), nullable=True) + auth_scheme: Mapped[str] = mapped_column( + String(16), nullable=False, default="bearer" + ) + header_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + query_param: Mapped[str | None] = mapped_column(String(255), nullable=True) + # The API key itself. Write-only over the API; never returned. See PRD 9.4.4. + secret: Mapped[str] = mapped_column(Text, nullable=False) + secret_last4: Mapped[str | None] = mapped_column(String(4), nullable=True) + enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now) + updated_at: Mapped[int] = mapped_column( + BigInteger, nullable=False, default=_now, onupdate=_now + ) + + __table_args__ = ( + Index("uq_host_credentials_host", "host", unique=True), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/model_downloader/database/queries.py b/app/model_downloader/database/queries.py new file mode 100644 index 000000000..6b5cbebe0 --- /dev/null +++ b/app/model_downloader/database/queries.py @@ -0,0 +1,235 @@ +"""Synchronous DB access for the download manager. + +All functions open their own short-lived session via ``create_session`` and +commit before returning, mirroring ``app/assets`` usage. They are blocking +(SQLite) and should be called from async code through ``asyncio.to_thread``. +""" + +from __future__ import annotations + +import time +from typing import Optional + +from sqlalchemy import select + +from app.database.db import create_session +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database.models import ( + Download, + DownloadSegment, + HostCredential, +) + + +# ----- downloads ----- + + +def insert_download(values: dict) -> None: + with create_session() as session: + session.add(Download(**values)) + session.commit() + + +def get_download(download_id: str) -> Optional[Download]: + with create_session() as session: + row = session.get(Download, download_id) + if row is not None: + session.expunge_all() + return row + + +def list_downloads() -> list[Download]: + with create_session() as session: + rows = list( + session.execute( + select(Download).order_by(Download.created_at.desc()) + ).scalars() + ) + session.expunge_all() + return rows + + +def list_segments(download_id: str) -> list[DownloadSegment]: + with create_session() as session: + rows = list( + session.execute( + select(DownloadSegment) + .where(DownloadSegment.download_id == download_id) + .order_by(DownloadSegment.idx) + ).scalars() + ) + session.expunge_all() + return rows + + +def update_download(download_id: str, **fields) -> None: + if not fields: + return + fields.setdefault("updated_at", int(time.time())) + with create_session() as session: + row = session.get(Download, download_id) + if row is None: + return + for key, value in fields.items(): + setattr(row, key, value) + session.commit() + + +def delete_download(download_id: str) -> None: + with create_session() as session: + row = session.get(Download, download_id) + if row is not None: + session.delete(row) + session.commit() + + +def replace_segments(download_id: str, segments: list[dict]) -> None: + """Atomically replace the segment plan for a download.""" + with create_session() as session: + session.query(DownloadSegment).filter( + DownloadSegment.download_id == download_id + ).delete() + for seg in segments: + session.add(DownloadSegment(download_id=download_id, **seg)) + session.commit() + + +def update_segment_progress(download_id: str, idx: int, bytes_done: int) -> None: + with create_session() as session: + row = session.get(DownloadSegment, {"download_id": download_id, "idx": idx}) + if row is None: + return + row.bytes_done = bytes_done + session.commit() + + +def list_queued_downloads() -> list[Download]: + """Queued rows ordered for admission (priority desc, then FIFO).""" + with create_session() as session: + rows = list( + session.execute( + select(Download) + .where(Download.status == DownloadStatus.QUEUED) + .order_by(Download.priority.desc(), Download.created_at.asc()) + ).scalars() + ) + session.expunge_all() + return rows + + +def reconcile_live_downloads() -> list[Download]: + """Reset any ``active``/``verifying`` rows left by a previous run. + + On a clean restart there can be no live worker, so anything still marked + live is stale. Move it back to ``queued`` (offsets are preserved on the + segment rows) so the scheduler re-admits it. Returns the rows that should + be re-queued by the scheduler (queued + paused). + """ + with create_session() as session: + stale = list( + session.execute( + select(Download).where( + Download.status.in_([DownloadStatus.ACTIVE, DownloadStatus.VERIFYING]) + ) + ).scalars() + ) + now = int(time.time()) + for row in stale: + row.status = DownloadStatus.QUEUED + row.updated_at = now + session.commit() + + resumable = list( + session.execute( + select(Download) + .where(Download.status == DownloadStatus.QUEUED) + .order_by(Download.priority.desc(), Download.created_at.asc()) + ).scalars() + ) + session.expunge_all() + return resumable + + +# ----- host credentials ----- + + +def get_credential(credential_id: str) -> Optional[HostCredential]: + with create_session() as session: + row = session.get(HostCredential, credential_id) + if row is not None: + session.expunge_all() + return row + + +def get_credential_by_host(host: str) -> Optional[HostCredential]: + with create_session() as session: + row = ( + session.execute( + select(HostCredential).where(HostCredential.host == host).limit(1) + ) + .scalars() + .first() + ) + if row is not None: + session.expunge_all() + return row + + +def list_credentials() -> list[HostCredential]: + with create_session() as session: + rows = list( + session.execute( + select(HostCredential).order_by(HostCredential.host) + ).scalars() + ) + session.expunge_all() + return rows + + +def list_subdomain_credentials() -> list[HostCredential]: + """Credentials that opted into subdomain matching, for suffix checks.""" + with create_session() as session: + rows = list( + session.execute( + select(HostCredential).where(HostCredential.match_subdomains.is_(True)) + ).scalars() + ) + session.expunge_all() + return rows + + +def upsert_credential(values: dict) -> HostCredential: + """Insert or update a credential keyed by ``host``.""" + host = values["host"] + now = int(time.time()) + with create_session() as session: + row = ( + session.execute( + select(HostCredential).where(HostCredential.host == host).limit(1) + ) + .scalars() + .first() + ) + if row is None: + row = HostCredential(**values) + row.created_at = now + row.updated_at = now + session.add(row) + else: + for key, value in values.items(): + setattr(row, key, value) + row.updated_at = now + session.commit() + session.refresh(row) + session.expunge(row) + return row + + +def delete_credential(credential_id: str) -> bool: + with create_session() as session: + row = session.get(HostCredential, credential_id) + if row is None: + return False + session.delete(row) + session.commit() + return True diff --git a/app/model_downloader/engine/job.py b/app/model_downloader/engine/job.py new file mode 100644 index 000000000..461478eae --- /dev/null +++ b/app/model_downloader/engine/job.py @@ -0,0 +1,443 @@ +"""The per-download worker (PRD sections 5, 6, 8, 12). + +One :class:`DownloadJob` drives a single file from probe to verified, cataloged +completion. It supports cooperative pause / resume / cancel, segmented +multi-connection transfer with positioned writes, and a verification gate +(size + structural + optional sha256) before the atomic rename into place. + +Control is cooperative: external callers flip ``_control`` via +:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between +chunks and raise, which unwinds cleanly and persists resume offsets. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import time +from dataclasses import dataclass, field +from typing import Callable, Optional + +from comfy.cli_args import args +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.engine.planner import ( + SegmentPlan, + effective_segment_count, + plan_segments, +) +from app.model_downloader.engine.writer import FileWriter +from app.model_downloader.net.http import open_validated +from app.model_downloader.net.probe import probe +from app.model_downloader.verify import checksum, dedup, structural + +_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504} +_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists + + +class Paused(Exception): + pass + + +class Cancelled(Exception): + pass + + +class RemoteChanged(Exception): + """The remote file changed under a resume (got 200 where 206 expected).""" + + +class RetryableError(Exception): + pass + + +class FatalError(Exception): + """Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc.""" + + +@dataclass +class SegmentRuntime: + idx: int + start: int + end: int # inclusive; may be -1 for unknown-size single stream + bytes_done: int = 0 + + @property + def length(self) -> int: + return self.end - self.start + 1 + + +@dataclass +class RuntimeState: + download_id: str + model_id: str + url: str + priority: int + status: str + total_bytes: Optional[int] = None + bytes_done: int = 0 + error: Optional[str] = None + segments: list[SegmentRuntime] = field(default_factory=list) + started_at: float = field(default_factory=time.monotonic) + _last_bytes: int = 0 + _last_time: float = field(default_factory=time.monotonic) + speed_bps: float = 0.0 + + @property + def progress(self) -> Optional[float]: + if not self.total_bytes: + return None + return min(1.0, self.bytes_done / self.total_bytes) + + @property + def eta_seconds(self) -> Optional[float]: + if not self.total_bytes or self.speed_bps <= 0: + return None + remaining = max(0, self.total_bytes - self.bytes_done) + return remaining / self.speed_bps + + +@dataclass +class JobSpec: + download_id: str + url: str + model_id: str + dest_path: str + temp_path: str + priority: int = 0 + credential_id: Optional[str] = None + expected_sha256: Optional[str] = None + allow_any_extension: bool = False + etag: Optional[str] = None + attempts: int = 0 + + +class DownloadJob: + def __init__( + self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None + ) -> None: + self.spec = spec + self._notify = notify_cb + self._control = "run" # run | pause | cancel + self.state = RuntimeState( + download_id=spec.download_id, + model_id=spec.model_id, + url=spec.url, + priority=spec.priority, + status=DownloadStatus.QUEUED, + ) + self._writer: Optional[FileWriter] = None + self._etag: Optional[str] = spec.etag + self._last_persist = 0.0 + + # ----- external control ----- + + def request_pause(self) -> None: + if self._control == "run": + self._control = "pause" + + def request_cancel(self) -> None: + self._control = "cancel" + + def _check_control(self) -> None: + if self._control == "cancel": + raise Cancelled() + if self._control == "pause": + raise Paused() + + # ----- lifecycle ----- + + async def run(self) -> str: + """Run to a terminal/paused state; returns the final status string.""" + self._set_status(DownloadStatus.ACTIVE, error=None) + try: + pr = await self._probe_and_plan() + await self._transfer(pr) + await self._finalize() + self._set_status(DownloadStatus.COMPLETED) + except Paused: + await self._persist_progress(force=True) + self._set_status(DownloadStatus.PAUSED) + except Cancelled: + await self._close_writer() + self._remove_temp() + self._set_status(DownloadStatus.CANCELLED) + except RemoteChanged: + await self._reset_for_restart() + self._set_status( + DownloadStatus.QUEUED, error="remote file changed; restarting" + ) + except RetryableError as e: + await self._persist_progress(force=True) + self._set_status(DownloadStatus.QUEUED, error=str(e)) + except FatalError as e: + await self._close_writer() + self._remove_temp() + self._set_status(DownloadStatus.FAILED, error=str(e)) + except Exception as e: # unexpected -> treat as retryable + logging.warning( + "[model_downloader] %s unexpected error: %s", + self.spec.model_id, e, exc_info=True, + ) + await self._persist_progress(force=True) + self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}") + finally: + await self._close_writer() + return self.state.status + + # ----- probe + plan ----- + + async def _probe_and_plan(self): + pr = await probe(self.spec.url, credential_id=self.spec.credential_id) + if not pr.ok: + if pr.gated: + raise FatalError( + f"{self.spec.url} requires authentication. Add an API key for " + f"this host at /api/download/credentials and retry." + ) + if pr.status == 0 or pr.status in _RETRYABLE_STATUSES: + raise RetryableError(pr.error or "probe failed") + raise FatalError(pr.error or f"probe returned HTTP {pr.status}") + + self._etag = pr.etag or self._etag + self.state.total_bytes = pr.total_bytes + queries.update_download( + self.spec.download_id, + final_url=pr.final_url, + total_bytes=pr.total_bytes, + accept_ranges=pr.accept_ranges, + etag=pr.etag, + last_modified=pr.last_modified, + ) + + seg_count = effective_segment_count( + pr.total_bytes, pr.accept_ranges, max(1, args.download_segments) + ) + existing = queries.list_segments(self.spec.download_id) + if ( + seg_count > 1 + and existing + and pr.total_bytes is not None + and existing[-1].end_offset == pr.total_bytes - 1 + ): + # Resume an existing segmented plan. + self.state.segments = [ + SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done) + for s in existing + ] + elif seg_count > 1 and pr.total_bytes is not None: + plans = plan_segments(pr.total_bytes, seg_count) + queries.replace_segments( + self.spec.download_id, + [ + {"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0} + for p in plans + ], + ) + self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans] + else: + # Single-stream: one logical segment; bytes_done tracked on the row. + row = queries.get_download(self.spec.download_id) + resume_from = row.bytes_done if row else 0 + end = (pr.total_bytes - 1) if pr.total_bytes else -1 + self.state.segments = [SegmentRuntime(0, 0, end, resume_from)] + self._recompute_bytes_done() + return pr + + # ----- transfer ----- + + async def _transfer(self, pr) -> None: + self._writer = FileWriter(self.spec.temp_path) + await self._writer.open() + + segmented = len(self.state.segments) > 1 + if segmented and self.state.total_bytes: + await self._writer.preallocate(self.state.total_bytes) + await self._run_segmented() + else: + await self._run_single() + + await self._writer.flush() + + async def _run_segmented(self) -> None: + pending = [ + asyncio.ensure_future(self._run_segment(seg)) + for seg in self.state.segments + if seg.bytes_done < seg.length + ] + if not pending: + return + done, not_done = await asyncio.wait( + pending, return_when=asyncio.FIRST_EXCEPTION + ) + first_exc: Optional[BaseException] = None + for task in done: + exc = task.exception() + if exc is not None and first_exc is None: + first_exc = exc + if first_exc is not None: + for task in not_done: + task.cancel() + await asyncio.gather(*not_done, return_exceptions=True) + raise first_exc + + async def _run_segment(self, seg: SegmentRuntime) -> None: + offset = seg.start + seg.bytes_done + headers = { + "Range": f"bytes={offset}-{seg.end}", + "Accept-Encoding": "identity", + } + if self._etag: + headers["If-Range"] = self._etag + async with open_validated( + "GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers + ) as (resp, _final): + if resp.status == 200: + # Server ignored the range -> remote changed / no resume support. + raise RemoteChanged() + if resp.status not in (206,): + self._raise_for_status(resp.status) + async for chunk in resp.content.iter_chunked(args.download_chunk_size): + self._check_control() + await self._writer.write_at(offset, chunk) + offset += len(chunk) + seg.bytes_done += len(chunk) + self._recompute_bytes_done() + await self._persist_progress() + + async def _run_single(self) -> None: + seg = self.state.segments[0] + offset = seg.bytes_done # resume from here for single-stream + headers = {"Accept-Encoding": "identity"} + if offset > 0: + headers["Range"] = f"bytes={offset}-" + if self._etag: + headers["If-Range"] = self._etag + async with open_validated( + "GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers + ) as (resp, _final): + if offset > 0 and resp.status == 200: + # Resume not honoured -> start over from the beginning. + offset = 0 + seg.bytes_done = 0 + elif offset > 0 and resp.status != 206: + self._raise_for_status(resp.status) + elif offset == 0 and resp.status != 200: + self._raise_for_status(resp.status) + async for chunk in resp.content.iter_chunked(args.download_chunk_size): + self._check_control() + await self._writer.write_at(offset, chunk) + offset += len(chunk) + seg.bytes_done = offset + self.state.bytes_done = offset + await self._persist_progress() + + def _raise_for_status(self, status: int) -> None: + if status in (401, 403): + raise FatalError( + f"{self.spec.url} returned {status}; add/update an API key for " + f"this host at /api/download/credentials." + ) + if status in _RETRYABLE_STATUSES: + raise RetryableError(f"HTTP {status}") + raise FatalError(f"unexpected HTTP {status}") + + # ----- finalize / verify (PRD section 8.4) ----- + + async def _finalize(self) -> None: + self._check_control() + await self._close_writer() + self._set_status(DownloadStatus.VERIFYING) + + total = self.state.total_bytes + actual_size = os.path.getsize(self.spec.temp_path) + if total is not None and actual_size != total: + raise FatalError( + f"size mismatch: wrote {actual_size} of {total} bytes" + ) + + # Structural gate (cheap, no full read) then optional sha256 (full read). + await asyncio.to_thread(structural.validate, self.spec.temp_path) + if self.spec.expected_sha256: + await asyncio.to_thread( + checksum.verify_sha256, self.spec.temp_path, self.spec.expected_sha256 + ) + + os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True) + os.replace(self.spec.temp_path, self.spec.dest_path) + logging.info( + "[model_downloader] completed %s (%d bytes)", + self.spec.model_id, actual_size, + ) + # Catalog into the assets system (blake3 dedup identity). Best-effort. + await dedup.register_completed(self.spec.dest_path) + + # ----- helpers ----- + + def _recompute_bytes_done(self) -> None: + self.state.bytes_done = sum(s.bytes_done for s in self.state.segments) + now = time.monotonic() + dt = now - self.state._last_time + if dt >= 0.5: + self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt + self.state._last_bytes = self.state.bytes_done + self.state._last_time = now + + async def _persist_progress(self, force: bool = False) -> None: + now = time.monotonic() + if not force and now - self._last_persist < _PERSIST_INTERVAL: + if self._notify: + self._notify(self.spec.download_id) + return + self._last_persist = now + queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done) + for seg in self.state.segments: + if seg.end >= seg.start: # skip unknown-size sentinel + queries.update_segment_progress( + self.spec.download_id, seg.idx, seg.bytes_done + ) + if self._notify: + self._notify(self.spec.download_id) + + async def _reset_for_restart(self) -> None: + await self._close_writer() + self._remove_temp() + for seg in self.state.segments: + seg.bytes_done = 0 + self.state.bytes_done = 0 + queries.update_download(self.spec.download_id, bytes_done=0) + if queries.list_segments(self.spec.download_id): + queries.replace_segments(self.spec.download_id, []) + + async def _close_writer(self) -> None: + if self._writer is not None: + try: + await self._writer.close() + except Exception: + logging.debug("[model_downloader] writer close error", exc_info=True) + self._writer = None + + def _remove_temp(self) -> None: + try: + os.remove(self.spec.temp_path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning( + "[model_downloader] could not remove %s: %s", self.spec.temp_path, e + ) + + def _set_status(self, status: str, error: Optional[str] = None) -> None: + self.state.status = status + if error is not None: + self.state.error = error + fields = {"status": status, "bytes_done": self.state.bytes_done} + if error is not None: + fields["error"] = error + if status == DownloadStatus.QUEUED: + fields["attempts"] = self.spec.attempts + 1 + self.spec.attempts += 1 + queries.update_download(self.spec.download_id, **fields) + if self._notify: + self._notify(self.spec.download_id) diff --git a/app/model_downloader/engine/planner.py b/app/model_downloader/engine/planner.py new file mode 100644 index 000000000..3ace350b0 --- /dev/null +++ b/app/model_downloader/engine/planner.py @@ -0,0 +1,51 @@ +"""Segment planning (PRD section 5.2). + +Split a known byte range into S roughly-equal segments, each fetched by its +own coroutine with ``Range: bytes=start-end``. Falls back to a single segment +when the server doesn't support ranges or the size is unknown/too small for +segmentation to be worthwhile. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +# Below this size, the per-connection setup cost outweighs any parallelism. +_MIN_SEGMENT_BYTES = 1 * 1024 * 1024 + + +@dataclass(frozen=True) +class SegmentPlan: + idx: int + start: int + end: int # inclusive + + @property + def length(self) -> int: + return self.end - self.start + 1 + + +def effective_segment_count( + total_bytes: int | None, accept_ranges: bool, configured: int +) -> int: + """How many segments to actually use for this file.""" + if not accept_ranges or total_bytes is None or total_bytes <= 0: + return 1 + by_size = max(1, total_bytes // _MIN_SEGMENT_BYTES) + return max(1, min(configured, by_size)) + + +def plan_segments(total_bytes: int, num_segments: int) -> list[SegmentPlan]: + """Return ``num_segments`` contiguous, inclusive byte ranges covering [0, total).""" + if total_bytes <= 0 or num_segments <= 1: + return [SegmentPlan(idx=0, start=0, end=max(0, total_bytes - 1))] + base = total_bytes // num_segments + plans: list[SegmentPlan] = [] + start = 0 + for i in range(num_segments): + # Last segment soaks up the remainder. + length = base if i < num_segments - 1 else total_bytes - start + end = start + length - 1 + plans.append(SegmentPlan(idx=i, start=start, end=end)) + start = end + 1 + return plans diff --git a/app/model_downloader/engine/writer.py b/app/model_downloader/engine/writer.py new file mode 100644 index 000000000..ada676475 --- /dev/null +++ b/app/model_downloader/engine/writer.py @@ -0,0 +1,61 @@ +"""Positioned, off-loop file writes (PRD section 4 + 5.2). + +Network I/O stays on the event loop; every blocking disk op (preallocate, +positioned write, fsync) is run in a bounded thread pool via +``run_in_executor`` so downloads never stall inference or the web server. + +A single file descriptor is opened for the whole download. Segments write to +their own offsets with ``os.pwrite`` — which is offset-addressed and atomic +per call, so concurrent segment writers need no extra locking. Per-chunk +fsync is avoided; we fsync once at completion. +""" + +from __future__ import annotations + +import asyncio +import os +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +# One shared, bounded pool for all download disk I/O. +_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer") + + +class FileWriter: + """Owns the ``.part`` file descriptor for one download.""" + + def __init__(self, path: str) -> None: + self.path = path + self._fd: Optional[int] = None + + def _open(self) -> None: + os.makedirs(os.path.dirname(self.path), exist_ok=True) + self._fd = os.open(self.path, os.O_RDWR | os.O_CREAT, 0o644) + + async def open(self) -> None: + await asyncio.get_running_loop().run_in_executor(_EXECUTOR, self._open) + + async def preallocate(self, size: int) -> None: + """Grow the file to ``size`` so segments write to their offsets.""" + if self._fd is None or size <= 0: + return + await asyncio.get_running_loop().run_in_executor( + _EXECUTOR, os.ftruncate, self._fd, size + ) + + async def write_at(self, offset: int, data: bytes) -> None: + assert self._fd is not None, "writer not opened" + await asyncio.get_running_loop().run_in_executor( + _EXECUTOR, os.pwrite, self._fd, data, offset + ) + + async def flush(self) -> None: + if self._fd is None: + return + await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.fsync, self._fd) + + async def close(self) -> None: + if self._fd is None: + return + fd, self._fd = self._fd, None + await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.close, fd) diff --git a/app/model_downloader/manager.py b/app/model_downloader/manager.py new file mode 100644 index 000000000..5ce8176d2 --- /dev/null +++ b/app/model_downloader/manager.py @@ -0,0 +1,294 @@ +"""Public facade for the download manager (PRD section 10). + +This is the only object the server imports. It validates requests, owns the +:class:`Scheduler`, and exposes a small async API plus read models for status. +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from typing import Callable, Optional + +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.scheduler import SCHEDULER +from app.model_downloader.security import paths +from app.model_downloader.security.allowlist import is_url_allowed +from app.model_downloader.security.paths import InvalidModelId + +# Non-terminal statuses: an existing row in one of these blocks a re-enqueue. +_LIVE_STATUSES = ( + DownloadStatus.QUEUED, + DownloadStatus.ACTIVE, + DownloadStatus.PAUSED, + DownloadStatus.VERIFYING, +) + + +class DownloadError(Exception): + """A user-facing error with a stable machine-readable code.""" + + def __init__(self, code: str, message: str, status: int = 400) -> None: + super().__init__(message) + self.code = code + self.message = message + self.http_status = status + + +class DownloadManager: + def __init__(self) -> None: + self._scheduler = SCHEDULER + self._notify_cb: Optional[Callable[[str], None]] = None + + def set_notify(self, cb: Optional[Callable[[str], None]]) -> None: + self._notify_cb = cb + self._scheduler.set_notify(cb) + + async def start(self) -> None: + await self._scheduler.start() + + # ----- enqueue ----- + + async def enqueue( + self, + url: str, + model_id: str, + *, + priority: int = 0, + expected_sha256: Optional[str] = None, + allow_any_extension: bool = False, + credential_id: Optional[str] = None, + ) -> str: + if not is_url_allowed(url, allow_any_extension): + raise DownloadError( + "URL_NOT_ALLOWED", + "URL is not on the download allowlist (host/scheme/extension).", + ) + try: + paths.parse_model_id(model_id, allow_any_extension) + dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension) + except InvalidModelId as e: + raise DownloadError("INVALID_MODEL_ID", str(e)) + + if await asyncio.to_thread( + paths.resolve_existing, model_id, allow_any_extension + ): + raise DownloadError( + "ALREADY_AVAILABLE", + f"Model already exists on disk: {model_id}", + status=409, + ) + if await self._has_live_download(model_id): + raise DownloadError( + "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress.", + status=409, + ) + + download_id = str(uuid.uuid4()) + await asyncio.to_thread( + queries.insert_download, + { + "id": download_id, + "url": url, + "model_id": model_id, + "dest_path": dest_path, + "temp_path": temp_path, + "status": DownloadStatus.QUEUED, + "priority": priority, + "expected_sha256": expected_sha256, + "credential_id": credential_id, + "allow_any_extension": allow_any_extension, + }, + ) + logging.info("[model_downloader] enqueued %s -> %s", url, model_id) + await self._scheduler.pump() + return download_id + + async def _has_live_download(self, model_id: str) -> bool: + rows = await asyncio.to_thread(queries.list_downloads) + return any( + r.model_id == model_id and r.status in _LIVE_STATUSES for r in rows + ) + + # ----- control ----- + + async def pause(self, download_id: str) -> None: + job = self._scheduler.get_job(download_id) + if job is not None: + job.request_pause() + return + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + if row.status == DownloadStatus.QUEUED: + await asyncio.to_thread( + queries.update_download, download_id, status=DownloadStatus.PAUSED + ) + + async def resume(self, download_id: str) -> None: + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + if row.status in (DownloadStatus.PAUSED, DownloadStatus.FAILED): + await asyncio.to_thread( + queries.update_download, + download_id, + status=DownloadStatus.QUEUED, + error=None, + ) + await self._scheduler.pump() + + async def cancel(self, download_id: str) -> None: + job = self._scheduler.get_job(download_id) + if job is not None: + job.request_cancel() + return + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + if row.status in _LIVE_STATUSES: + import os + + try: + os.remove(row.temp_path) + except OSError: + pass + await asyncio.to_thread( + queries.update_download, download_id, status=DownloadStatus.CANCELLED + ) + + async def set_priority(self, download_id: str, priority: int) -> None: + row = await asyncio.to_thread(queries.get_download, download_id) + if row is None: + raise DownloadError("NOT_FOUND", "No such download.", status=404) + await asyncio.to_thread( + queries.update_download, download_id, priority=priority + ) + # Admission-order only (PRD section 13 default); a higher priority is + # picked up the next time a slot frees. Pump in case a slot is free now. + await self._scheduler.pump() + + # ----- read models ----- + + def _view(self, row) -> dict: + """Combine the persisted row with live in-memory progress, if running.""" + job = self._scheduler.get_job(row.id) + bytes_done = row.bytes_done + total = row.total_bytes + speed = None + eta = None + segments = None + if job is not None: + st = job.state + bytes_done = st.bytes_done + total = st.total_bytes if st.total_bytes is not None else total + speed = st.speed_bps + eta = st.eta_seconds + segments = [ + {"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length} + for s in st.segments + if s.end >= s.start + ] + progress = (bytes_done / total) if total else None + return { + "download_id": row.id, + "model_id": row.model_id, + "url": row.url, + "status": row.status, + "priority": row.priority, + "total_bytes": total, + "bytes_done": bytes_done, + "progress": progress, + "speed_bps": speed, + "eta_seconds": eta, + "segments": segments, + "error": row.error, + "created_at": row.created_at, + "updated_at": row.updated_at, + } + + def _view_from_state(self, job) -> dict: + """Build a view purely from the live in-memory job state (no DB).""" + st = job.state + return { + "download_id": st.download_id, + "model_id": st.model_id, + "url": st.url, + "status": st.status, + "priority": st.priority, + "total_bytes": st.total_bytes, + "bytes_done": st.bytes_done, + "progress": st.progress, + "speed_bps": st.speed_bps, + "eta_seconds": st.eta_seconds, + "segments": [ + {"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length} + for s in st.segments + if s.end >= s.start + ], + "error": st.error, + } + + def status_sync(self, download_id: str) -> Optional[dict]: + """Synchronous status read for the websocket notify path. + + Uses live in-memory state when the job is running (no DB round-trip on + the hot path); falls back to a quick DB read otherwise. + """ + job = self._scheduler.get_job(download_id) + if job is not None: + return self._view_from_state(job) + row = queries.get_download(download_id) + return self._view(row) if row is not None else None + + async def status(self, download_id: str) -> Optional[dict]: + row = await asyncio.to_thread(queries.get_download, download_id) + return self._view(row) if row is not None else None + + async def list(self) -> list[dict]: + rows = await asyncio.to_thread(queries.list_downloads) + return [self._view(r) for r in rows] + + async def availability(self, models: dict[str, str]) -> dict[str, dict]: + """Bulk per-id ``{state, progress, ...}`` for the frontend poll. + + ``state`` is ``available`` (on disk), ``downloading`` (live row), or + ``missing``. Cheap: a path lookup plus an in-memory/DB status check. + """ + rows = await asyncio.to_thread(queries.list_downloads) + by_model: dict[str, object] = {} + for r in rows: + if r.status in _LIVE_STATUSES or r.model_id not in by_model: + by_model[r.model_id] = r + + out: dict[str, dict] = {} + for model_id, url in models.items(): + try: + exists = await asyncio.to_thread(paths.resolve_existing, model_id) + except InvalidModelId: + out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)} + continue + if exists: + out[model_id] = {"state": "available", "url_allowed": is_url_allowed(url)} + continue + row = by_model.get(model_id) + if row is not None and row.status in _LIVE_STATUSES: + view = self._view(row) + out[model_id] = { + "state": "downloading", + "url_allowed": is_url_allowed(url), + "download_id": view["download_id"], + "progress": view["progress"], + "bytes_done": view["bytes_done"], + "total_bytes": view["total_bytes"], + "speed_bps": view["speed_bps"], + } + else: + out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)} + return out + + +DOWNLOAD_MANAGER = DownloadManager() diff --git a/app/model_downloader/net/http.py b/app/model_downloader/net/http.py new file mode 100644 index 000000000..33309b0ce --- /dev/null +++ b/app/model_downloader/net/http.py @@ -0,0 +1,110 @@ +"""Manual, validated redirect-following request opener. + +Automatic redirects are disabled (PRD section 9.2): we follow hops ourselves +so that on *every* hop we (a) re-validate scheme + reject credentials-in-URL, +(b) recompute which stored credential — if any — applies to that hop's host, +and (c) let the connector's resolver screen the IP. This is the single place +that attaches credentials, so a token can never ride a redirect to a CDN host. +""" + +from __future__ import annotations + +import logging +from contextlib import asynccontextmanager +from typing import AsyncIterator, Optional +from urllib.parse import urljoin, urlsplit, urlunsplit + +import aiohttp + +from app.model_downloader.credentials.resolver import resolve_auth_for_hop +from app.model_downloader.net.session import get_session +from app.model_downloader.security.ssrf import ( + MAX_REDIRECTS, + SSRFError, + check_redirect_hop, +) + +_REDIRECT_CODES = {301, 302, 303, 307, 308} +DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120) + + +def redact_url(url: str) -> str: + """Drop the query string so a query-scheme secret is never logged/stored.""" + try: + parts = urlsplit(url) + except ValueError: + return "" + return urlunsplit(parts._replace(query="")) + + +async def _resolve_final_response( + method: str, + url: str, + credential_id: Optional[str], + base_headers: dict[str, str], + timeout: aiohttp.ClientTimeout, +) -> tuple[aiohttp.ClientResponse, str]: + """Follow redirects manually until a non-redirect response. + + Each intermediate redirect response is released before the next hop. + Returns the final ``(response, final_url)``; the caller owns releasing it. + """ + session = await get_session() + current = url + hops = 0 + while True: + check_redirect_hop(current) + parts = urlsplit(current) + auth = await resolve_auth_for_hop( + parts.hostname or "", parts.scheme, explicit_credential_id=credential_id + ) + req_headers = dict(base_headers) + req_url = current + if auth is not None: + req_headers.update(auth.headers) + req_url = auth.apply_to_url(current) + + resp = await session.request( + method, + req_url, + allow_redirects=False, + headers=req_headers, + timeout=timeout, + ) + if resp.status in _REDIRECT_CODES and resp.headers.get("Location"): + next_url = urljoin(str(resp.url), resp.headers["Location"]) + await resp.release() + hops += 1 + if hops > MAX_REDIRECTS: + raise SSRFError( + f"too many redirects (> {MAX_REDIRECTS}) for {redact_url(url)}" + ) + current = next_url + continue + return resp, redact_url(str(resp.url)) + + +@asynccontextmanager +async def open_validated( + method: str, + url: str, + *, + credential_id: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT, +) -> AsyncIterator[tuple[aiohttp.ClientResponse, str]]: + """Open ``method url`` following redirects manually and validated. + + Yields ``(response, final_url)`` where ``final_url`` is redacted of any + query string. The response is released automatically on exit. + """ + resp, final_url = await _resolve_final_response( + method, url, credential_id, dict(headers or {}), timeout + ) + try: + yield resp, final_url + finally: + try: + await resp.release() + except Exception: # pragma: no cover - best-effort cleanup + logging.debug("[model_downloader] response release error", exc_info=True) diff --git a/app/model_downloader/net/probe.py b/app/model_downloader/net/probe.py new file mode 100644 index 000000000..289aa0c74 --- /dev/null +++ b/app/model_downloader/net/probe.py @@ -0,0 +1,90 @@ +"""Pre-download probe (PRD section 5.1). + +Issues a tiny ranged GET (``Range: bytes=0-0``) — which doubles as a +range-support test — to discover ``Content-Length``, ``Accept-Ranges``, +``ETag``/``Last-Modified``, and the final post-redirect URL. For HuggingFace +LFS files the true size also appears in the non-standard ``X-Linked-Size`` +header, which we read as a fallback. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Optional + +import aiohttp + +from app.model_downloader.net.http import open_validated +from app.model_downloader.net.session import parse_int_header + +_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30) + + +@dataclass +class ProbeResult: + ok: bool + status: int + final_url: Optional[str] = None + total_bytes: Optional[int] = None + accept_ranges: bool = False + etag: Optional[str] = None + last_modified: Optional[str] = None + gated: bool = False # 401/403 — needs (or has wrong) credentials + error: Optional[str] = None + + +def _total_from_content_range(value: Optional[str]) -> Optional[int]: + # "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None + if not value or "/" not in value: + return None + total = value.rsplit("/", 1)[1].strip() + return parse_int_header(total) + + +async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult: + """Probe ``url`` and return discovered metadata, failing soft.""" + try: + async with open_validated( + "GET", + url, + credential_id=credential_id, + headers={"Range": "bytes=0-0", "Accept-Encoding": "identity"}, + timeout=_PROBE_TIMEOUT, + ) as (resp, final_url): + if resp.status in (401, 403): + return ProbeResult( + ok=False, status=resp.status, final_url=final_url, gated=True, + error=f"host returned {resp.status} (authentication required)", + ) + if resp.status not in (200, 206): + return ProbeResult( + ok=False, status=resp.status, final_url=final_url, + error=f"probe returned HTTP {resp.status}", + ) + + headers = resp.headers + accept_ranges = False + total: Optional[int] = None + if resp.status == 206: + accept_ranges = True + total = _total_from_content_range(headers.get("Content-Range")) + else: # 200: server ignored the range + accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes" + total = parse_int_header(headers.get("Content-Length")) + + if total is None: + total = parse_int_header(headers.get("X-Linked-Size")) + + return ProbeResult( + ok=True, + status=resp.status, + final_url=final_url, + total_bytes=total, + accept_ranges=accept_ranges, + etag=headers.get("ETag"), + last_modified=headers.get("Last-Modified"), + ) + except Exception as e: # network / SSRF / timeout + logging.debug("[model_downloader] probe failed for %s: %s", url, e) + return ProbeResult(ok=False, status=0, error=f"{type(e).__name__}: {e}") diff --git a/app/model_downloader/net/session.py b/app/model_downloader/net/session.py new file mode 100644 index 000000000..368bd924a --- /dev/null +++ b/app/model_downloader/net/session.py @@ -0,0 +1,72 @@ +"""Lazily-created shared :class:`aiohttp.ClientSession`. + +A single session reuses TLS handshakes and TCP connections across the probe +and the many segment GETs to the same host (HuggingFace is the dominant +case), which is a large speedup on cold connections and exactly the +connection-reuse strategy that lets us match aria2c (PRD section 5.2). + +The connector uses :class:`ValidatingResolver` so every connection — initial +or post-redirect — is screened for private/special-use IPs at connect time. +TLS is pinned to certifi's CA bundle because the OS trust store is not wired +up on some Python installs (python.org macOS, slim containers). +""" + +from __future__ import annotations + +import asyncio +import ssl +from typing import Optional + +import aiohttp + +try: + import certifi + _CA_FILE = certifi.where() +except Exception: # pragma: no cover - certifi is a transitive dep of aiohttp + _CA_FILE = None + +from comfy.cli_args import args +from app.model_downloader.security.ssrf import ValidatingResolver + +_session: Optional[aiohttp.ClientSession] = None +_lock = asyncio.Lock() + + +def ssl_context() -> ssl.SSLContext: + if _CA_FILE is not None: + return ssl.create_default_context(cafile=_CA_FILE) + return ssl.create_default_context() + + +async def get_session() -> aiohttp.ClientSession: + """Return the shared session, creating it on first use.""" + global _session + if _session is not None and not _session.closed: + return _session + async with _lock: + if _session is None or _session.closed: + connector = aiohttp.TCPConnector( + limit_per_host=max(1, getattr(args, "download_max_connections_per_host", 16)), + ssl=ssl_context(), + resolver=ValidatingResolver(), + ) + _session = aiohttp.ClientSession(connector=connector) + return _session + + +async def close_session() -> None: + global _session + if _session is not None and not _session.closed: + await _session.close() + _session = None + + +def parse_int_header(value: Optional[str]) -> Optional[int]: + """Parse a non-negative integer header value, or None if bad/absent.""" + if not value: + return None + try: + n = int(value) + except (TypeError, ValueError): + return None + return n if n >= 0 else None diff --git a/app/model_downloader/scheduler.py b/app/model_downloader/scheduler.py new file mode 100644 index 000000000..89043c2a9 --- /dev/null +++ b/app/model_downloader/scheduler.py @@ -0,0 +1,160 @@ +"""Priority scheduler + lifecycle (PRD sections 4, 6, 12). + +Owns the set of running jobs and admits queued downloads up to a global +concurrency limit (K), highest priority first, FIFO within a priority. Runs +entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing, +DB) is offloaded by the job/writer layers. + +On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a +previous run are reset to ``queued`` and resumed from persisted offsets, and +orphaned ``.part`` files with no live download row are swept. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import random +import time +from typing import Callable, Optional + +from comfy.cli_args import args +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.engine.job import DownloadJob, JobSpec +from app.model_downloader.security import paths + +# Backoff for retryable failures (PRD section 12). +_BACKOFF_BASE = 2.0 +_BACKOFF_CAP = 300.0 +_MAX_ATTEMPTS = 6 + + +class Scheduler: + def __init__(self) -> None: + self._jobs: dict[str, DownloadJob] = {} + self._tasks: dict[str, asyncio.Task] = {} + self._backoff_until: dict[str, float] = {} + self._pump_lock = asyncio.Lock() + self._notify_cb: Optional[Callable[[str], None]] = None + self._started = False + + @property + def max_active(self) -> int: + return max(1, getattr(args, "download_max_active", 3)) + + def set_notify(self, cb: Optional[Callable[[str], None]]) -> None: + self._notify_cb = cb + + def get_job(self, download_id: str) -> Optional[DownloadJob]: + return self._jobs.get(download_id) + + def is_active(self, download_id: str) -> bool: + return download_id in self._tasks + + # ----- startup ----- + + async def start(self) -> None: + if self._started: + return + self._started = True + try: + await asyncio.to_thread(queries.reconcile_live_downloads) + await asyncio.to_thread(self._sweep_orphan_temp_files) + except Exception as e: + logging.warning("[model_downloader] startup reconcile failed: %s", e) + await self.pump() + + @staticmethod + def _sweep_orphan_temp_files() -> None: + """Remove ``.part`` files not referenced by a resumable download row. + + Resumable partials (queued/paused rows) are preserved; only truly + orphaned temp files from crashed runs are deleted. + """ + live = { + row.temp_path + for row in queries.list_downloads() + if row.status in (DownloadStatus.QUEUED, DownloadStatus.PAUSED) + } + for path in paths.iter_all_tmp_paths(): + if path in live: + continue + try: + os.remove(path) + logging.info("[model_downloader] removed orphan temp file: %s", path) + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + # ----- admission ----- + + async def pump(self) -> None: + async with self._pump_lock: + slots = self.max_active - len(self._tasks) + if slots <= 0: + return + now = time.monotonic() + candidates = await asyncio.to_thread(queries.list_queued_downloads) + for row in candidates: + if slots <= 0: + break + if row.id in self._tasks: + continue + if self._backoff_until.get(row.id, 0.0) > now: + continue + self._admit(row) + slots -= 1 + + def _admit(self, row) -> None: + spec = JobSpec( + download_id=row.id, + url=row.url, + model_id=row.model_id, + dest_path=row.dest_path, + temp_path=row.temp_path, + priority=row.priority, + credential_id=row.credential_id, + expected_sha256=row.expected_sha256, + allow_any_extension=row.allow_any_extension, + etag=row.etag, + attempts=row.attempts, + ) + job = DownloadJob(spec, notify_cb=self._notify_cb) + self._jobs[row.id] = job + self._tasks[row.id] = asyncio.ensure_future(self._run_job(job)) + + async def _run_job(self, job: DownloadJob) -> None: + download_id = job.spec.download_id + status = DownloadStatus.FAILED + try: + status = await job.run() + except Exception as e: # run() is defensive, but never let a task die silently + logging.error("[model_downloader] job %s crashed: %s", download_id, e) + finally: + self._tasks.pop(download_id, None) + self._jobs.pop(download_id, None) + + if status == DownloadStatus.QUEUED: + if job.spec.attempts >= _MAX_ATTEMPTS: + queries.update_download( + download_id, + status=DownloadStatus.FAILED, + error=f"giving up after {job.spec.attempts} attempts", + ) + if self._notify_cb: + self._notify_cb(download_id) + else: + delay = min( + _BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts + ) + random.uniform(0, 1.0) + self._backoff_until[download_id] = time.monotonic() + delay + asyncio.ensure_future(self._delayed_pump(delay)) + await self.pump() + + async def _delayed_pump(self, delay: float) -> None: + await asyncio.sleep(delay) + await self.pump() + + +SCHEDULER = Scheduler() diff --git a/app/model_downloader/security/allowlist.py b/app/model_downloader/security/allowlist.py new file mode 100644 index 000000000..b01c7f8de --- /dev/null +++ b/app/model_downloader/security/allowlist.py @@ -0,0 +1,84 @@ +"""URL allowlist for server-side model fetches (PRD section 9.1). + +Default-deny. A URL is downloadable only when its parsed host + scheme are +allowlisted AND (unless explicitly relaxed) its final filename ends in a +known model extension. + +The built-in host defaults mirror the frontend's ``isModelDownloadable`` +allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts`` +extends it for self-hosted mirrors. Matching is done on ``urlparse().hostname`` +(never a raw string prefix) so userinfo tricks like +``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the +metadata IP — cannot slip past. +""" + +from __future__ import annotations + +from urllib.parse import urlparse + +from comfy.cli_args import args + +# host -> set of allowed schemes. Frontend parity (HuggingFace / Civitai / +# localhost). Extra hosts from --download-allowed-hosts are https-only. +_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = { + "huggingface.co": {"https"}, + "civitai.com": {"https"}, + "localhost": {"http", "https"}, + "127.0.0.1": {"http", "https"}, +} + +# Hosts for which loopback addresses are intentionally permitted (the localhost +# "download a local model" feature). Every other host's loopback resolution is +# rejected by the SSRF resolver. +LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"}) + +# Known model file extensions (frontend parity). Checked on the final filename. +ALLOWED_MODEL_EXTENSIONS = ( + ".safetensors", + ".sft", + ".ckpt", + ".pth", + ".pt", + ".gguf", + ".bin", +) + + +def _allowed_hosts() -> dict[str, set[str]]: + hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()} + for extra in getattr(args, "download_allowed_hosts", []) or []: + host = extra.strip().lower() + if host: + hosts.setdefault(host, set()).add("https") + return hosts + + +def is_host_allowed(host: str | None, scheme: str | None) -> bool: + """True iff ``host`` is allowlisted for ``scheme``. + + Used both for the initial URL and re-checked on every redirect hop + (PRD section 9.2), so a whitelisted URL cannot 30x into an off-list host. + """ + if not host or not scheme: + return False + allowed = _allowed_hosts().get(host.lower()) + return allowed is not None and scheme.lower() in allowed + + +def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool: + if allow_any_extension: + return True + return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS) + + +def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool: + """Check whether ``url`` is permitted as a server-side download source.""" + if not isinstance(url, str) or not url: + return False + try: + parsed = urlparse(url) + except ValueError: + return False + if not is_host_allowed(parsed.hostname, parsed.scheme): + return False + return has_allowed_extension(parsed.path, allow_any_extension) diff --git a/app/model_downloader/security/paths.py b/app/model_downloader/security/paths.py new file mode 100644 index 000000000..883cb9b9c --- /dev/null +++ b/app/model_downloader/security/paths.py @@ -0,0 +1,110 @@ +"""Path resolution + traversal safety for downloads (PRD section 9.3). + +A ``model_id`` is a *relative destination path* of the form +``/`` (e.g. ``loras/my_lora.safetensors``). This module +turns one into an absolute on-disk path under one of ComfyUI's registered +model folders, rejecting unknown folders, path traversal, and symlink escape. +This is the only thing that composes destination paths, so the engine never +touches user-supplied path strings directly. +""" + +from __future__ import annotations + +import os +import re +from typing import Iterator, Optional + +import folder_paths + +from app.model_downloader.constants import TMP_SUFFIX +from app.model_downloader.security.allowlist import ALLOWED_MODEL_EXTENSIONS + +# A model_id component is a single path segment of safe characters — no slashes, +# no "..", no leading dots that could escape the target directory. +_SEGMENT_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") + + +class InvalidModelId(ValueError): + """Raised when a model_id is malformed or names an unknown model folder.""" + + +def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[str, str]: + """Split ``/`` and validate both components. + + Returns ``(directory, filename)``. Does not touch the filesystem. + """ + if not isinstance(model_id, str) or "/" not in model_id: + raise InvalidModelId( + f"model_id must be '/', got {model_id!r}" + ) + directory, _, filename = model_id.partition("/") + if "/" in filename or not directory or not filename: + raise InvalidModelId( + f"model_id must have exactly one '/' separator, got {model_id!r}" + ) + if not _SEGMENT_RE.match(directory): + raise InvalidModelId(f"invalid directory segment {directory!r}") + if not _SEGMENT_RE.match(filename): + raise InvalidModelId(f"invalid filename segment {filename!r}") + if not allow_any_extension and not filename.lower().endswith( + ALLOWED_MODEL_EXTENSIONS + ): + raise InvalidModelId( + f"filename must end with a known model extension " + f"{ALLOWED_MODEL_EXTENSIONS}, got {filename!r}" + ) + if directory not in folder_paths.folder_names_and_paths: + raise InvalidModelId(f"unknown model folder {directory!r}") + return directory, filename + + +def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]: + """Return the absolute path of an installed model, or None if missing. + + Honours ``extra_model_paths.yaml`` transparently via ``get_full_path``. + """ + directory, filename = parse_model_id(model_id, allow_any_extension) + return folder_paths.get_full_path(directory, filename) + + +def resolve_destination( + model_id: str, allow_any_extension: bool = False +) -> tuple[str, str]: + """Return ``(final_path, temp_path)`` for a download. + + Downloads land at the first registered path for the model's directory + (the "primary" location). ``temp_path`` is a sibling ``.part`` file that + is atomically renamed onto ``final_path`` on success. The result is + asserted to stay within the registered root (defence in depth on top of + the segment regex). + """ + directory, filename = parse_model_id(model_id, allow_any_extension) + roots = folder_paths.get_folder_paths(directory) + if not roots: + raise InvalidModelId(f"no on-disk path registered for folder {directory!r}") + root = os.path.realpath(roots[0]) + final_path = os.path.realpath(os.path.join(root, filename)) + if final_path != root and not final_path.startswith(root + os.sep): + raise InvalidModelId(f"resolved path escapes model root: {model_id!r}") + temp_path = f"{final_path}{TMP_SUFFIX}" + return final_path, temp_path + + +def iter_all_tmp_paths() -> Iterator[str]: + """Yield this subsystem's temp files under every registered model folder. + + Matches only the distinctive ``TMP_SUFFIX`` so the startup orphan sweep + can never delete temp files created by other tools. + """ + seen_roots: set[str] = set() + for directory in list(folder_paths.folder_names_and_paths.keys()): + for root in folder_paths.get_folder_paths(directory): + if root in seen_roots or not os.path.isdir(root): + continue + seen_roots.add(root) + try: + for entry in os.scandir(root): + if entry.is_file() and entry.name.endswith(TMP_SUFFIX): + yield entry.path + except OSError: + continue diff --git a/app/model_downloader/security/ssrf.py b/app/model_downloader/security/ssrf.py new file mode 100644 index 000000000..4cb63715e --- /dev/null +++ b/app/model_downloader/security/ssrf.py @@ -0,0 +1,111 @@ +"""SSRF / exfiltration defenses (PRD section 9.2). + +Two cooperating layers: + +1. :class:`ValidatingResolver` is installed on the shared connector. Every + connection — the initial probe and every segment GET, including ones made + after a redirect — resolves its host through this resolver, which rejects + any address that lands on a private / special-use IP range. Because the + resolve and the connect happen together inside the connector, there is no + check-then-connect window for DNS rebinding to exploit. + +2. :func:`check_redirect_hop` re-validates every redirect hop. The host + allowlist gates only the *initial* user-supplied URL (anti-SSRF for + arbitrary input); legitimate downloads from allowlisted origins redirect + to presigned CDN hosts that are deliberately NOT on the allowlist (HF -> + ``cdn-lfs*.huggingface.co``, Civitai -> signed Cloudflare/S3), so hops are + instead screened for scheme, embedded credentials, and — via the resolver + above — private IPs. Credentials are only ever attached when a hop's host + exactly matches a stored credential, so they are dropped on the CDN hop. +""" + +from __future__ import annotations + +import ipaddress +import socket +from urllib.parse import urlparse + +from aiohttp.abc import AbstractResolver +from aiohttp.resolver import DefaultResolver + +from app.model_downloader.security.allowlist import LOOPBACK_HOSTS + +# Cap the redirect chain length and the schemes a hop may use. +MAX_REDIRECTS = 5 +ALLOWED_SCHEMES = ("https", "http") + + +class SSRFError(Exception): + """A hop failed an SSRF / allowlist check.""" + + +def is_blocked_ip(ip_str: str) -> bool: + """True for any address we refuse to connect to. + + Covers loopback, link-local (incl. 169.254.169.254 cloud metadata), + RFC1918 private ranges, unique-local (ULA), unspecified (0.0.0.0/::), + multicast and other reserved ranges. + """ + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + return True # unparseable -> refuse + return ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_multicast + or ip.is_reserved + or ip.is_unspecified + ) + + +class ValidatingResolver(AbstractResolver): + """Delegating resolver that drops blocked IPs from every resolution. + + If a hostname resolves only to blocked addresses, the connection fails + closed with an :class:`OSError`, which aiohttp surfaces as a connection + error to the caller. + """ + + def __init__(self) -> None: + self._inner = DefaultResolver() + + async def resolve(self, host, port=0, family=socket.AF_INET): + infos = await self._inner.resolve(host, port, family) + # localhost/127.0.0.1 are an explicit, opt-in allowlist feature. + if isinstance(host, str) and host.lower() in LOOPBACK_HOSTS: + return infos + safe = [info for info in infos if not is_blocked_ip(info["host"])] + if not safe: + raise OSError( + f"refusing to connect to {host!r}: resolves only to " + f"private/special-use addresses" + ) + return safe + + async def close(self) -> None: + await self._inner.close() + + +def check_redirect_hop(url: str) -> str: + """Validate one redirect hop's URL. + + Returns the URL unchanged on success; raises :class:`SSRFError` otherwise. + Enforces an allowed scheme and forbids credentials-in-URL. The host is NOT + re-checked against the allowlist (CDN redirect targets are off-list by + design); private-IP protection is provided by the connector's resolver, + and credential leakage is prevented by exact host matching at attach time. + The landing filename's extension is gated separately by the caller. + """ + try: + parsed = urlparse(url) + except ValueError as e: + raise SSRFError(f"unparseable redirect URL {url!r}: {e}") from e + if parsed.scheme.lower() not in ALLOWED_SCHEMES: + raise SSRFError(f"redirect to disallowed scheme {parsed.scheme!r}") + if parsed.username or parsed.password: + raise SSRFError("credentials-in-URL are not allowed") + if not parsed.hostname: + raise SSRFError(f"redirect URL has no host: {url!r}") + return url diff --git a/app/model_downloader/verify/checksum.py b/app/model_downloader/verify/checksum.py new file mode 100644 index 000000000..fccced2e5 --- /dev/null +++ b/app/model_downloader/verify/checksum.py @@ -0,0 +1,49 @@ +"""Hub-checksum verification = SHA256 (PRD section 8.1). + +Only used to confirm a download matches a *provided* ``expected_sha256``. It +is NOT the dedup key (that is blake3, owned by the assets system). The full +sequential read happens at most once, here, only when a checksum was supplied. +""" + +from __future__ import annotations + +import hashlib +from typing import Callable, Optional + +_CHUNK = 8 * 1024 * 1024 + +InterruptCheck = Callable[[], bool] + + +class ChecksumError(Exception): + """The computed SHA256 did not match the expected value.""" + + +def sha256_file(path: str, interrupt_check: Optional[InterruptCheck] = None) -> Optional[str]: + """Stream the file and return its lowercase hex SHA256. + + Returns ``None`` if interrupted via ``interrupt_check``. + """ + h = hashlib.sha256() + with open(path, "rb") as f: + while True: + if interrupt_check is not None and interrupt_check(): + return None + chunk = f.read(_CHUNK) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + + +def verify_sha256( + path: str, expected: str, interrupt_check: Optional[InterruptCheck] = None +) -> None: + """Raise :class:`ChecksumError` unless the file's SHA256 matches ``expected``.""" + actual = sha256_file(path, interrupt_check) + if actual is None: + return # interrupted; caller will re-verify on resume + if actual.lower() != expected.lower(): + raise ChecksumError( + f"sha256 mismatch: expected {expected.lower()}, got {actual.lower()}" + ) diff --git a/app/model_downloader/verify/dedup.py b/app/model_downloader/verify/dedup.py new file mode 100644 index 000000000..7b48a95d5 --- /dev/null +++ b/app/model_downloader/verify/dedup.py @@ -0,0 +1,53 @@ +"""Dedup + catalog handoff — reuse the assets system (PRD section 8.5). + +We do NOT build a parallel indexer. "Do I already have it?" is answered by +``resolve_existing`` (path) at enqueue time and, where a hash is known, by the +assets blake3 catalog. After a completed download we register the file +through the assets ingest path so it is cataloged and (eventually) hashed by +the existing enrichment worker. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Optional + + +def _register_sync(abs_path: str) -> Optional[str]: + """Register a finished file into the assets catalog. Returns asset hash.""" + try: + from app.assets.services.ingest import register_file_in_place + except Exception as e: # assets package import failure — non-fatal + logging.debug("[model_downloader] assets ingest unavailable: %s", e) + return None + try: + result = register_file_in_place(abs_path, name=os.path.basename(abs_path), tags=[]) + return result.asset.hash if result and result.asset else None + except Exception as e: + # The file is already safely on disk; cataloging is best-effort. + logging.warning( + "[model_downloader] could not register %s into assets catalog: %s", + abs_path, e, + ) + return None + + +async def register_completed(abs_path: str) -> Optional[str]: + """Catalog a completed download via the assets system (off the event loop).""" + return await asyncio.to_thread(_register_sync, abs_path) + + +def _find_by_hash_sync(blake3_hex: str) -> Optional[str]: + try: + from app.assets.services.asset_management import get_asset_by_hash + except Exception: + return None + asset = get_asset_by_hash("blake3:" + blake3_hex) + return asset.hash if asset is not None else None + + +async def find_existing_by_hash(blake3_hex: str) -> Optional[str]: + """Pure DB lookup — never triggers hashing on the hot path.""" + return await asyncio.to_thread(_find_by_hash_sync, blake3_hex) diff --git a/app/model_downloader/verify/structural.py b/app/model_downloader/verify/structural.py new file mode 100644 index 000000000..5575ff145 --- /dev/null +++ b/app/model_downloader/verify/structural.py @@ -0,0 +1,65 @@ +"""Cheap structural validation, no full read (PRD section 8.2). + +For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries +the tensor table and the byte length of the data region. We assert +``file_size == 8 + header_len + data_region_len``. This detects truncation +and most corruption for free, before any crypto hashing. Other extensions +have no cheap structural check and pass through. +""" + +from __future__ import annotations + +import json +import os +import struct + +_SAFETENSORS_EXTS = (".safetensors", ".sft") +# A sane upper bound so a corrupt header length can't make us read gigabytes. +_MAX_HEADER_BYTES = 100 * 1024 * 1024 + + +class StructuralError(Exception): + """The file failed its structural integrity check.""" + + +def validate(path: str) -> None: + """Validate the file at ``path``. Raises :class:`StructuralError` on failure.""" + lower = path.lower() + if lower.endswith(_SAFETENSORS_EXTS): + _validate_safetensors(path) + # No structural check for other formats; the size + (optional) checksum + # gates in the engine cover those. + + +def _validate_safetensors(path: str) -> None: + file_size = os.path.getsize(path) + if file_size < 8: + raise StructuralError(f"file too small to be safetensors ({file_size} bytes)") + with open(path, "rb") as f: + header_len = struct.unpack(" _MAX_HEADER_BYTES: + raise StructuralError(f"implausible safetensors header length {header_len}") + if 8 + header_len > file_size: + raise StructuralError("safetensors header extends past end of file") + try: + header = json.loads(f.read(header_len).decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError) as e: + raise StructuralError(f"safetensors header is not valid JSON: {e}") from e + + data_len = 0 + for name, entry in header.items(): + if name == "__metadata__": + continue + if not isinstance(entry, dict) or "data_offsets" not in entry: + raise StructuralError(f"tensor {name!r} missing data_offsets") + offsets = entry["data_offsets"] + if not (isinstance(offsets, list) and len(offsets) == 2): + raise StructuralError(f"tensor {name!r} has malformed data_offsets") + data_len = max(data_len, int(offsets[1])) + + expected = 8 + header_len + data_len + if file_size != expected: + raise StructuralError( + f"size mismatch: file is {file_size} bytes, header implies {expected} " + f"(8 + {header_len} header + {data_len} data)" + ) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index e3099a230..9130d5f15 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -243,6 +243,14 @@ parser.add_argument("--enable-assets", action="store_true", help="Enable the ass parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button") parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.") +# ----- Model download manager (PRD: docs/prd-download-manager.md) ----- +parser.add_argument("--download-segments", type=int, default=8, metavar="N", help="Number of parallel HTTP range segments per file for the model download manager (default: 8).") +parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).") +parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).") +parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).") +parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.") +parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0f30608a9..4ce977c20 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -104,6 +104,7 @@ _CORE_FEATURE_FLAGS: dict[str, Any] = { "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, "assets": args.enable_assets, + "server_side_model_downloads": True, } # CLI-provided flags cannot overwrite core flags diff --git a/openapi.yaml b/openapi.yaml index c6a8621cc..f731d374b 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -230,6 +230,93 @@ components: - base_version - workflow_json type: object + DownloadEnqueueRequest: + description: Request body for enqueuing a server-side model download. + properties: + allow_any_extension: + default: false + description: Permit a non-model file extension (default only allows known model extensions). + type: boolean + credential_id: + description: Explicit per-host credential to use; otherwise auto-resolved by host. Still subject to the per-hop host match. + nullable: true + type: string + expected_sha256: + description: Optional hub-provided SHA256 to verify the completed file against (fail-closed). + nullable: true + type: string + model_id: + description: Destination as "/", resolving to a registered model folder (e.g. "loras/my_lora.safetensors"). + type: string + priority: + default: 0 + description: Scheduling priority; higher is admitted first. + type: integer + url: + description: Source URL; must be on the allowlist (host + scheme + extension). + type: string + required: + - url + - model_id + type: object + DownloadStatus: + description: Current state and live progress of a single download. + properties: + bytes_done: + type: integer + created_at: + type: integer + download_id: + format: uuid + type: string + error: + nullable: true + type: string + eta_seconds: + nullable: true + type: number + model_id: + type: string + priority: + type: integer + progress: + description: Fraction in [0,1]; null until total size is known. + nullable: true + type: number + segments: + description: Per-segment progress (segmented downloads only). + items: + properties: + bytes_done: + type: integer + idx: + type: integer + length: + type: integer + type: object + nullable: true + type: array + speed_bps: + nullable: true + type: number + status: + enum: + - queued + - active + - paused + - verifying + - completed + - failed + - cancelled + type: string + total_bytes: + nullable: true + type: integer + updated_at: + type: integer + url: + type: string + type: object ErrorResponse: description: Standard error response with a machine-readable code and human-readable message. properties: @@ -511,6 +598,78 @@ components: required: - history type: object + HostCredentialUpsert: + description: Request body for upserting a per-host credential. The secret is write-only. + properties: + auth_scheme: + default: bearer + description: How the secret is attached to requests. + enum: + - bearer + - header + - query + type: string + enabled: + default: true + type: boolean + header_name: + description: Header name when auth_scheme=header (defaults to Authorization). + nullable: true + type: string + host: + description: Normalized hostname the key applies to (e.g. "civitai.com"). + type: string + label: + description: User-friendly name for display. + nullable: true + type: string + match_subdomains: + default: false + description: Also match label-boundary subdomains of host (off by default; unsafe for hub CDNs). + type: boolean + query_param: + description: Query parameter name when auth_scheme=query. + nullable: true + type: string + secret: + description: The API key. Write-only — never returned by any endpoint. + type: string + required: + - host + - secret + type: object + HostCredentialView: + description: Masked, API-safe view of a stored credential. Never includes the secret. + properties: + auth_scheme: + type: string + created_at: + type: integer + enabled: + type: boolean + header_name: + nullable: true + type: string + host: + type: string + id: + format: uuid + type: string + label: + nullable: true + type: string + match_subdomains: + type: boolean + query_param: + nullable: true + type: string + secret_last4: + description: Last 4 characters of the secret, for masked display only. + nullable: true + type: string + updated_at: + type: integer + type: object JobCancelResponse: description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops. properties: @@ -2350,6 +2509,330 @@ paths: summary: Get tag histogram for filtered assets tags: - file + /api/download: + get: + description: List all known downloads (queued, active, paused, and terminal) with live progress. + operationId: listDownloads + responses: + "200": + content: + application/json: + schema: + properties: + downloads: + items: + $ref: '#/components/schemas/DownloadStatus' + type: array + type: object + description: List of downloads + summary: List downloads + tags: + - download + /api/download/availability: + post: + description: | + Bulk per-id availability for a set of model_ids declared in a workflow. + Returns whether each model is available on disk, currently downloading + (with progress), or missing, plus whether its URL is on the allowlist. + operationId: getModelsAvailability + requestBody: + content: + application/json: + schema: + properties: + models: + additionalProperties: + type: string + description: Map of "/" model_id to its declared source URL. + type: object + type: object + responses: + "200": + content: + application/json: + schema: + properties: + models: + additionalProperties: true + type: object + type: object + description: Per-id availability map + summary: Bulk model availability + status + tags: + - download + /api/download/credentials: + get: + description: List stored per-host credentials. Secrets are never returned; only masked metadata (last 4 chars, scheme, label). + operationId: listDownloadCredentials + responses: + "200": + content: + application/json: + schema: + properties: + credentials: + items: + $ref: '#/components/schemas/HostCredentialView' + type: array + type: object + description: Masked credential list + summary: List host credentials (masked) + tags: + - download + post: + description: | + Upsert (by host) a per-host API key used to authenticate downloads. + The secret is write-only: it is stored once here and never returned by any endpoint. + operationId: upsertDownloadCredential + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/HostCredentialUpsert' + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/HostCredentialView' + description: Credential stored (masked view returned) + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid credential + summary: Upsert a host credential + tags: + - download + /api/download/credentials/{id}: + delete: + description: Delete a stored host credential. + operationId: deleteDownloadCredential + parameters: + - in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + properties: + deleted: + type: boolean + type: object + description: Deleted + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: No such credential + summary: Delete a host credential + tags: + - download + get: + description: Get a single host credential (masked; never includes the secret). + operationId: getDownloadCredential + parameters: + - in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/HostCredentialView' + description: Masked credential + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: No such credential + summary: Get a host credential (masked) + tags: + - download + /api/download/enqueue: + post: + description: | + Enqueue a server-side model download. The URL must be on the allowlist + (host + scheme + extension) and the model_id must be "/" + resolving to a registered model folder. Returns immediately; track progress + via GET /api/download/{id} or the "download_progress" websocket event. + operationId: enqueueDownload + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DownloadEnqueueRequest' + responses: + "202": + content: + application/json: + schema: + properties: + accepted: + type: boolean + download_id: + format: uuid + type: string + type: object + description: Download accepted and queued + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request (bad URL, model_id, or not allowlisted) + "409": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Already on disk or already downloading + summary: Enqueue a model download + tags: + - download + /api/download/{id}: + get: + description: Get the current status + progress of a single download. + operationId: getDownloadStatus + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/DownloadStatus' + description: Download status + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: No such download + summary: Get download status + tags: + - download + /api/download/{id}/cancel: + post: + description: Cancel a download. The partial file is removed. + operationId: cancelDownload + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Cancelled + summary: Cancel a download + tags: + - download + /api/download/{id}/pause: + post: + description: Pause a download. The partial file and per-segment offsets are retained for resume. + operationId: pauseDownload + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Paused + summary: Pause a download + tags: + - download + /api/download/{id}/priority: + post: + description: Set a download's scheduling priority. Higher priority is admitted first when a slot frees. + operationId: setDownloadPriority + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + requestBody: + content: + application/json: + schema: + properties: + priority: + type: integer + required: + - priority + type: object + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Priority updated + summary: Set download priority + tags: + - download + /api/download/{id}/resume: + post: + description: Resume a paused (or failed) download from its persisted offsets. + operationId: resumeDownload + parameters: + - in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + properties: + ok: + type: boolean + type: object + description: Resumed + summary: Resume a download + tags: + - download /api/embeddings: get: description: Returns the list of text-encoder embeddings available on disk. diff --git a/server.py b/server.py index 361850f38..fbf411a19 100644 --- a/server.py +++ b/server.py @@ -45,6 +45,8 @@ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal from app.assets.seeder import asset_seeder from app.assets.api.routes import register_assets_routes +from app.model_downloader.api.routes import register_routes as register_model_downloader_routes +from app.model_downloader.manager import DOWNLOAD_MANAGER from app.assets.services.ingest import register_file_in_place from app.assets.services.asset_management import resolve_hash_to_path @@ -256,6 +258,7 @@ class PromptServer(): else: register_assets_routes(self.app) asset_seeder.disable() + register_model_downloader_routes(self.app) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -1182,6 +1185,24 @@ class PromptServer(): async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout self.client_session = aiohttp.ClientSession(timeout=timeout) + await self._setup_model_downloader() + + async def _setup_model_downloader(self): + """Start the download manager: push progress over the websocket and + resume any downloads interrupted by a previous run.""" + def _notify(download_id: str) -> None: + try: + view = DOWNLOAD_MANAGER.status_sync(download_id) + if view is not None: + self.send_sync("download_progress", view) + except Exception: + logging.debug("download progress notify failed", exc_info=True) + + DOWNLOAD_MANAGER.set_notify(_notify) + try: + await DOWNLOAD_MANAGER.start() + except Exception as e: + logging.warning("Failed to start model download manager: %s", e) def add_routes(self): self.user_manager.add_routes(self.routes) diff --git a/tests-unit/model_downloader_test/conftest.py b/tests-unit/model_downloader_test/conftest.py new file mode 100644 index 000000000..b285c3693 --- /dev/null +++ b/tests-unit/model_downloader_test/conftest.py @@ -0,0 +1,63 @@ +"""Shared fixtures for the model download manager tests. + +These run in-process (no ComfyUI subprocess): a file-backed SQLite DB is +initialized once, a temp model folder is registered with ``folder_paths``, and +the shared aiohttp session is reset between tests so each async test gets a +session bound to its own event loop. +""" + +from __future__ import annotations + +import os +import tempfile + +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def _init_db(): + import app.database.db as db + from comfy.cli_args import args + + db_path = tempfile.mktemp(suffix="-dlmgr-test.sqlite3") + args.database_url = f"sqlite:///{db_path}" + db.init_db() + yield + try: + os.remove(db_path) + except OSError: + pass + + +@pytest.fixture(autouse=True) +def _reset_runtime(): + """Reset module singletons that hold event-loop-bound or cross-test state.""" + import app.model_downloader.net.session as ns + from app.model_downloader.scheduler import SCHEDULER + + ns._session = None + SCHEDULER._jobs.clear() + SCHEDULER._tasks.clear() + SCHEDULER._backoff_until.clear() + SCHEDULER._started = False + yield + ns._session = None + + +@pytest.fixture +def model_root(tmp_path): + """Register a temp 'loras' model folder and return its absolute path.""" + import folder_paths + + root = tmp_path / "loras" + root.mkdir(parents=True, exist_ok=True) + saved = folder_paths.folder_names_and_paths.get("loras") + folder_paths.folder_names_and_paths["loras"] = ( + [str(root)], + {".safetensors", ".sft", ".ckpt", ".pt", ".pth"}, + ) + yield str(root) + if saved is not None: + folder_paths.folder_names_and_paths["loras"] = saved + else: + folder_paths.folder_names_and_paths.pop("loras", None) diff --git a/tests-unit/model_downloader_test/test_credentials.py b/tests-unit/model_downloader_test/test_credentials.py new file mode 100644 index 000000000..199dfcabc --- /dev/null +++ b/tests-unit/model_downloader_test/test_credentials.py @@ -0,0 +1,110 @@ +"""Unit tests for the credential store and the per-hop credential resolver. + +Covers the critical rule: a secret is only ever attached when the current +hop's host matches a stored credential, and never over a non-https hop. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from app.model_downloader.credentials import resolver +from app.model_downloader.credentials.store import ( + CREDENTIAL_STORE, + CredentialValidationError, + normalize_host, +) +from app.model_downloader.database.models import HostCredential + + +# ----- pure host normalization + matching ----- + + +@pytest.mark.parametrize( + "raw,expected", + [ + ("Civitai.com", "civitai.com"), + ("HuggingFace.co:443", "huggingface.co"), + (" Example.COM ", "example.com"), + ], +) +def test_normalize_host(raw, expected): + assert normalize_host(raw) == expected + + +def _cred(**kw) -> HostCredential: + base = dict( + id="x", host="civitai.com", match_subdomains=False, auth_scheme="bearer", + secret="SECRET", enabled=True, + ) + base.update(kw) + return HostCredential(**base) + + +def test_matches_exact_only_by_default(): + c = _cred(host="civitai.com") + assert resolver._matches(c, "civitai.com") is True + assert resolver._matches(c, "api.civitai.com") is False + assert resolver._matches(c, "evil-civitai.com") is False + + +def test_matches_subdomain_label_boundary(): + c = _cred(host="example.com", match_subdomains=True) + assert resolver._matches(c, "api.example.com") is True + assert resolver._matches(c, "example.com") is True + # not a label boundary -> no match + assert resolver._matches(c, "evil-example.com") is False + + +def test_build_auth_shapes(): + assert resolver._build_auth(_cred(auth_scheme="bearer")).headers == { + "Authorization": "Bearer SECRET" + } + assert resolver._build_auth( + _cred(auth_scheme="header", header_name="X-Api-Key") + ).headers == {"X-Api-Key": "SECRET"} + q = resolver._build_auth(_cred(auth_scheme="query", query_param="token")) + assert q.query == {"token": "SECRET"} + assert q.apply_to_url("https://civitai.com/x") == "https://civitai.com/x?token=SECRET" + + +# ----- DB-backed store + resolver ----- + + +def test_store_upsert_is_write_only_and_masked(): + async def _run(): + view = await CREDENTIAL_STORE.upsert("civitai.com", "abcd1234", label="my key") + # The view never carries the secret, only the last 4. + assert not hasattr(view, "secret") + assert view.secret_last4 == "1234" + assert view.host == "civitai.com" + listed = await CREDENTIAL_STORE.list() + assert any(v.host == "civitai.com" for v in listed) + await CREDENTIAL_STORE.delete(view.id) + asyncio.run(_run()) + + +def test_query_scheme_requires_param(): + async def _run(): + with pytest.raises(CredentialValidationError): + await CREDENTIAL_STORE.upsert("civitai.com", "k", auth_scheme="query") + asyncio.run(_run()) + + +def test_resolver_never_crosses_host_boundary(): + async def _run(): + view = await CREDENTIAL_STORE.upsert("huggingface.co", "hf_secret_key") + try: + # matching host over https -> attached + auth = await resolver.resolve_auth_for_hop("huggingface.co", "https") + assert auth is not None + assert auth.headers["Authorization"] == "Bearer hf_secret_key" + # CDN redirect host -> dropped + assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None + # non-https hop -> never attached + assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None + finally: + await CREDENTIAL_STORE.delete(view.id) + asyncio.run(_run()) diff --git a/tests-unit/model_downloader_test/test_engine_integration.py b/tests-unit/model_downloader_test/test_engine_integration.py new file mode 100644 index 000000000..8818d6c8a --- /dev/null +++ b/tests-unit/model_downloader_test/test_engine_integration.py @@ -0,0 +1,270 @@ +"""Integration tests for the download engine against a local aiohttp server. + +Covers single-stream and segmented transfers, deterministic resume from a +partial file, and cancel rollback. Async tests are driven via ``asyncio.run`` +so no pytest-asyncio plugin is required. +""" + +from __future__ import annotations + +import asyncio +import os +import uuid + +import pytest +from aiohttp import web + +from comfy.cli_args import args +from app.model_downloader.constants import DownloadStatus +from app.model_downloader.database import queries +from app.model_downloader.engine.job import DownloadJob, JobSpec +from app.model_downloader.net.session import close_session +from app.model_downloader.security import paths + +PAYLOAD_ETAG = '"v1"' + + +def _payload(n: int) -> bytes: + return bytes((i * 37 + 11) % 256 for i in range(n)) + + +def _range_handler(payload: bytes): + async def handler(request: web.Request) -> web.Response: + rng = request.headers.get("Range") + if rng: + spec = rng.split("=", 1)[1] + s, _, e = spec.partition("-") + start = int(s) + end = int(e) if e else len(payload) - 1 + chunk = payload[start : end + 1] + return web.Response( + status=206, + body=chunk, + headers={ + "Content-Range": f"bytes {start}-{end}/{len(payload)}", + "Accept-Ranges": "bytes", + "ETag": PAYLOAD_ETAG, + }, + ) + return web.Response( + status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG} + ) + + return handler + + +def _noranges_handler(payload: bytes): + async def handler(request: web.Request) -> web.Response: + # Always full body, never advertises Accept-Ranges -> single-stream. + return web.Response(status=200, body=payload) + + return handler + + +def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01): + async def handler(request: web.Request) -> web.StreamResponse: + resp = web.StreamResponse( + status=200, headers={"Content-Length": str(len(payload))} + ) + await resp.prepare(request) + for i in range(0, len(payload), chunk): + await resp.write(payload[i : i + chunk]) + await asyncio.sleep(delay) + await resp.write_eof() + return resp + + return handler + + +async def _serve(handler): + app = web.Application() + app.router.add_route("*", "/{name:.*}", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] + return runner, port + + +def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]: + final_path, temp_path = paths.resolve_destination(model_id) + download_id = str(uuid.uuid4()) + queries.insert_download( + { + "id": download_id, + "url": url, + "model_id": model_id, + "dest_path": final_path, + "temp_path": temp_path, + "status": status, + } + ) + return download_id, final_path, temp_path + + +# ----- single-stream ----- + + +def test_single_stream_download(model_root): + payload = _payload(300_000) + + async def _run(): + await close_session() + runner, port = await _serve(_noranges_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, _temp = _insert("loras/single.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/single.safetensors", + dest_path=final_path, temp_path=_temp, + )) + status = await job.run() + assert status == DownloadStatus.COMPLETED, queries.get_download(did).error + assert os.path.exists(final_path) + assert open(final_path, "rb").read() == payload + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- segmented ----- + + +def test_segmented_download(model_root): + payload = _payload(4 * 1024 * 1024) # 4 MiB -> multiple segments + + async def _run(): + await close_session() + runner, port = await _serve(_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/seg.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/seg.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.COMPLETED, queries.get_download(did).error + assert open(final_path, "rb").read() == payload + # More than one segment row was planned. + assert len(queries.list_segments(did)) > 1 + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- deterministic resume from a partial file ----- + + +def test_resume_from_partial(model_root): + payload = _payload(512 * 1024) # < 1 MiB -> single segment, but ranges work + + async def _run(): + await close_session() + runner, port = await _serve(_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/resume.safetensors", url) + # Simulate a prior partial: first 200 KiB already written, offset persisted. + prefix = 200 * 1024 + os.makedirs(os.path.dirname(temp), exist_ok=True) + with open(temp, "wb") as f: + f.write(payload[:prefix]) + queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG) + + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/resume.safetensors", + dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG, + )) + status = await job.run() + assert status == DownloadStatus.COMPLETED, queries.get_download(did).error + assert open(final_path, "rb").read() == payload + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- cancel rollback ----- + + +def test_cancel_rollback(model_root, monkeypatch): + monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False) + payload = _payload(1024 * 1024) + + async def _run(): + await close_session() + runner, port = await _serve(_slow_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/cancel.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/cancel.safetensors", + dest_path=final_path, temp_path=temp, + )) + task = asyncio.ensure_future(job.run()) + # Wait until some bytes have been written, then cancel. + for _ in range(200): + await asyncio.sleep(0.01) + if job.state.bytes_done > 0: + break + job.request_cancel() + status = await task + assert status == DownloadStatus.CANCELLED + assert not os.path.exists(temp) + assert not os.path.exists(final_path) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +# ----- manager + scheduler end-to-end ----- + + +def test_manager_enqueue_to_completion(model_root): + payload = _payload(2 * 1024 * 1024) + + async def _run(): + await close_session() + from app.model_downloader.manager import DOWNLOAD_MANAGER + + runner, port = await _serve(_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors") + # Wait for completion. + final_path, _ = paths.resolve_destination("loras/e2e.safetensors") + for _ in range(500): + await asyncio.sleep(0.02) + row = queries.get_download(did) + if row.status in DownloadStatus.TERMINAL: + break + row = queries.get_download(did) + assert row.status == DownloadStatus.COMPLETED, row.error + assert open(final_path, "rb").read() == payload + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_manager_rejects_disallowed_url(model_root): + async def _run(): + from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError + + with pytest.raises(DownloadError) as ei: + await DOWNLOAD_MANAGER.enqueue( + "https://evil.example.com/x.safetensors", "loras/bad.safetensors" + ) + assert ei.value.code == "URL_NOT_ALLOWED" + + asyncio.run(_run()) diff --git a/tests-unit/model_downloader_test/test_planner_structural.py b/tests-unit/model_downloader_test/test_planner_structural.py new file mode 100644 index 000000000..33d4978af --- /dev/null +++ b/tests-unit/model_downloader_test/test_planner_structural.py @@ -0,0 +1,71 @@ +"""Unit tests for the segment planner and structural safetensors validation.""" + +from __future__ import annotations + +import json +import struct + +import pytest + +from app.model_downloader.engine.planner import ( + effective_segment_count, + plan_segments, +) +from app.model_downloader.verify import structural + + +# ----- planner ----- + + +def test_plan_segments_covers_full_range_contiguously(): + total = 1000 + plans = plan_segments(total, 4) + assert len(plans) == 4 + assert plans[0].start == 0 + assert plans[-1].end == total - 1 + # contiguous, no gaps/overlaps + for a, b in zip(plans, plans[1:]): + assert b.start == a.end + 1 + assert sum(p.length for p in plans) == total + + +def test_effective_segment_count_falls_back_to_single(): + # No range support -> single + assert effective_segment_count(10_000_000, False, 8) == 1 + # Unknown size -> single + assert effective_segment_count(None, True, 8) == 1 + # Tiny file -> fewer segments than configured + assert effective_segment_count(1024, True, 8) == 1 + # Large file with range support -> configured count + assert effective_segment_count(1_000_000_000, True, 8) == 8 + + +# ----- structural ----- + + +def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes: + header = {"t": {"dtype": "F32", "shape": [tensor_data_len], "data_offsets": [0, tensor_data_len]}} + header_bytes = json.dumps(header).encode("utf-8") + body = b"\x00" * tensor_data_len + if corrupt_size: + body = body[:-1] # truncate one byte + return struct.pack(" refuse + ], +) +def test_is_blocked_ip(ip, blocked): + assert is_blocked_ip(ip) is blocked + + +# ----- SSRF: redirect hop validation ----- + + +def test_check_redirect_hop_rejects_bad_scheme_and_userinfo(): + with pytest.raises(SSRFError): + check_redirect_hop("ftp://huggingface.co/x.safetensors") + with pytest.raises(SSRFError): + check_redirect_hop("https://user:pass@cdn.example.com/x") + # A CDN host that is NOT on the allowlist is allowed as a redirect target + # (private-IP protection is the resolver's job; credential leak is prevented + # by exact host matching). + assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None + + +# ----- path safety ----- + + +def test_parse_model_id_valid(model_root): + directory, filename = paths.parse_model_id("loras/my_lora.safetensors") + assert directory == "loras" + assert filename == "my_lora.safetensors" + + +@pytest.mark.parametrize( + "model_id", + [ + "loras/../etc/passwd.safetensors", # traversal + "loras/sub/dir.safetensors", # nested + "unknownfolder/x.safetensors", # unknown folder + "loras/model.txt", # bad extension + "noslash.safetensors", # missing directory + "loras/", # empty filename + ], +) +def test_parse_model_id_rejects(model_root, model_id): + with pytest.raises(paths.InvalidModelId): + paths.parse_model_id(model_id) + + +def test_resolve_destination_stays_in_root(model_root): + final_path, temp_path = paths.resolve_destination("loras/x.safetensors") + assert final_path.startswith(model_root) + assert temp_path.startswith(model_root) + assert temp_path != final_path