diff --git a/app/model_downloader/scheduler.py b/app/model_downloader/scheduler.py index 098bc76ec..b41d2cb3b 100644 --- a/app/model_downloader/scheduler.py +++ b/app/model_downloader/scheduler.py @@ -141,6 +141,13 @@ class Scheduler: status = await job.run() except Exception as e: # run() is defensive, but never let a task die silently logging.error("[model_downloader] job %s crashed: %s", download_id, e) + queries.update_download( + download_id, + status=DownloadStatus.FAILED, + error=f"internal error: {e}", + ) + if self._notify_cb: + self._notify_cb(download_id) finally: self._tasks.pop(download_id, None) self._jobs.pop(download_id, None) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5de165cf0..77ed0d49d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -44,6 +44,17 @@ def _positive_int(value: str) -> int: return ivalue +def _non_negative_int(value: str) -> int: + """argparse type that rejects negatives but allows zero (a disable sentinel).""" + try: + ivalue = int(value) + except ValueError: + raise argparse.ArgumentTypeError(f"{value!r} is not an integer") + if ivalue < 0: + raise argparse.ArgumentTypeError(f"{value!r} must be a non-negative integer (>= 0)") + return ivalue + + parser = argparse.ArgumentParser() parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)") @@ -259,7 +270,7 @@ parser.add_argument("--download-segments", type=_positive_int, default=8, metava parser.add_argument("--download-max-active", type=_positive_int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).") parser.add_argument("--download-max-connections-per-host", type=_positive_int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).") parser.add_argument("--download-chunk-size", type=_positive_int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).") -parser.add_argument("--download-max-bytes", type=int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).") +parser.add_argument("--download-max-bytes", type=_non_negative_int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).") parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.") parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).") diff --git a/tests-unit/model_downloader_test/conftest.py b/tests-unit/model_downloader_test/conftest.py index b285c3693..33165294b 100644 --- a/tests-unit/model_downloader_test/conftest.py +++ b/tests-unit/model_downloader_test/conftest.py @@ -8,18 +8,44 @@ session bound to its own event loop. from __future__ import annotations +import asyncio import os import tempfile import pytest +def _drain_scheduler_tasks(scheduler) -> None: + """Cancel and await live scheduler tasks so none outlive the test. + + Uses the actual task handles rather than only clearing ``_tasks``: each + per-test event loop is created by ``asyncio.run``, so a task left behind by + a crashed/aborted test would otherwise keep its coroutine alive. We cancel + every live task and, when its loop is still usable, run it to completion to + let the cancellation propagate before dropping the reference. + """ + for task in list(scheduler._tasks.values()): + if task is None: + continue + loop = task.get_loop() + if task.done() or loop.is_closed(): + continue + task.cancel() + if not loop.is_running(): + try: + loop.run_until_complete(asyncio.gather(task, return_exceptions=True)) + except Exception: + pass + scheduler._tasks.clear() + + @pytest.fixture(scope="session", autouse=True) def _init_db(): import app.database.db as db from comfy.cli_args import args - db_path = tempfile.mktemp(suffix="-dlmgr-test.sqlite3") + fd, db_path = tempfile.mkstemp(suffix="-dlmgr-test.sqlite3") + os.close(fd) args.database_url = f"sqlite:///{db_path}" db.init_db() yield @@ -36,11 +62,12 @@ def _reset_runtime(): from app.model_downloader.scheduler import SCHEDULER ns._session = None + _drain_scheduler_tasks(SCHEDULER) SCHEDULER._jobs.clear() - SCHEDULER._tasks.clear() SCHEDULER._backoff_until.clear() SCHEDULER._started = False yield + _drain_scheduler_tasks(SCHEDULER) ns._session = None diff --git a/tests-unit/model_downloader_test/test_planner_structural.py b/tests-unit/model_downloader_test/test_planner_structural.py index 33d4978af..139cf4094 100644 --- a/tests-unit/model_downloader_test/test_planner_structural.py +++ b/tests-unit/model_downloader_test/test_planner_structural.py @@ -69,3 +69,13 @@ def test_structural_skips_unknown_extension(tmp_path): p = tmp_path / "weights.bin" p.write_bytes(b"anything") structural.validate(str(p)) # no structural check, no raise + + +def test_structural_detects_truncation_via_name_hint(tmp_path): + # The downloader validates the opaque temp file (a ``.part`` path) but keys + # the format check off the final destination name via ``name_hint``, so + # truncation must still be detected instead of silently skipped. + p = tmp_path / "bad.comfy-download.part" + p.write_bytes(_make_safetensors(256, corrupt_size=True)) + with pytest.raises(structural.StructuralError): + structural.validate(str(p), name_hint="model.safetensors")