Fix more AI detected issues.,

This commit is contained in:
Talmaj Marinc 2026-07-01 11:38:33 +02:00
parent 3eb36377a8
commit 28b41d4d6d
4 changed files with 58 additions and 3 deletions

View File

@ -141,6 +141,13 @@ class Scheduler:
status = await job.run()
except Exception as e: # run() is defensive, but never let a task die silently
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
queries.update_download(
download_id,
status=DownloadStatus.FAILED,
error=f"internal error: {e}",
)
if self._notify_cb:
self._notify_cb(download_id)
finally:
self._tasks.pop(download_id, None)
self._jobs.pop(download_id, None)

View File

@ -44,6 +44,17 @@ def _positive_int(value: str) -> int:
return ivalue
def _non_negative_int(value: str) -> int:
"""argparse type that rejects negatives but allows zero (a disable sentinel)."""
try:
ivalue = int(value)
except ValueError:
raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
if ivalue < 0:
raise argparse.ArgumentTypeError(f"{value!r} must be a non-negative integer (>= 0)")
return ivalue
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
@ -259,7 +270,7 @@ parser.add_argument("--download-segments", type=_positive_int, default=8, metava
parser.add_argument("--download-max-active", type=_positive_int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
parser.add_argument("--download-max-connections-per-host", type=_positive_int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
parser.add_argument("--download-chunk-size", type=_positive_int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
parser.add_argument("--download-max-bytes", type=int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
parser.add_argument("--download-max-bytes", type=_non_negative_int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")

View File

@ -8,18 +8,44 @@ session bound to its own event loop.
from __future__ import annotations
import asyncio
import os
import tempfile
import pytest
def _drain_scheduler_tasks(scheduler) -> None:
"""Cancel and await live scheduler tasks so none outlive the test.
Uses the actual task handles rather than only clearing ``_tasks``: each
per-test event loop is created by ``asyncio.run``, so a task left behind by
a crashed/aborted test would otherwise keep its coroutine alive. We cancel
every live task and, when its loop is still usable, run it to completion to
let the cancellation propagate before dropping the reference.
"""
for task in list(scheduler._tasks.values()):
if task is None:
continue
loop = task.get_loop()
if task.done() or loop.is_closed():
continue
task.cancel()
if not loop.is_running():
try:
loop.run_until_complete(asyncio.gather(task, return_exceptions=True))
except Exception:
pass
scheduler._tasks.clear()
@pytest.fixture(scope="session", autouse=True)
def _init_db():
import app.database.db as db
from comfy.cli_args import args
db_path = tempfile.mktemp(suffix="-dlmgr-test.sqlite3")
fd, db_path = tempfile.mkstemp(suffix="-dlmgr-test.sqlite3")
os.close(fd)
args.database_url = f"sqlite:///{db_path}"
db.init_db()
yield
@ -36,11 +62,12 @@ def _reset_runtime():
from app.model_downloader.scheduler import SCHEDULER
ns._session = None
_drain_scheduler_tasks(SCHEDULER)
SCHEDULER._jobs.clear()
SCHEDULER._tasks.clear()
SCHEDULER._backoff_until.clear()
SCHEDULER._started = False
yield
_drain_scheduler_tasks(SCHEDULER)
ns._session = None

View File

@ -69,3 +69,13 @@ def test_structural_skips_unknown_extension(tmp_path):
p = tmp_path / "weights.bin"
p.write_bytes(b"anything")
structural.validate(str(p)) # no structural check, no raise
def test_structural_detects_truncation_via_name_hint(tmp_path):
# The downloader validates the opaque temp file (a ``.part`` path) but keys
# the format check off the final destination name via ``name_hint``, so
# truncation must still be detected instead of silently skipped.
p = tmp_path / "bad.comfy-download.part"
p.write_bytes(_make_safetensors(256, corrupt_size=True))
with pytest.raises(structural.StructuralError):
structural.validate(str(p), name_hint="model.safetensors")