mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Fix more AI detected issues.,
This commit is contained in:
parent
3eb36377a8
commit
28b41d4d6d
@ -141,6 +141,13 @@ class Scheduler:
|
||||
status = await job.run()
|
||||
except Exception as e: # run() is defensive, but never let a task die silently
|
||||
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
|
||||
queries.update_download(
|
||||
download_id,
|
||||
status=DownloadStatus.FAILED,
|
||||
error=f"internal error: {e}",
|
||||
)
|
||||
if self._notify_cb:
|
||||
self._notify_cb(download_id)
|
||||
finally:
|
||||
self._tasks.pop(download_id, None)
|
||||
self._jobs.pop(download_id, None)
|
||||
|
||||
@ -44,6 +44,17 @@ def _positive_int(value: str) -> int:
|
||||
return ivalue
|
||||
|
||||
|
||||
def _non_negative_int(value: str) -> int:
|
||||
"""argparse type that rejects negatives but allows zero (a disable sentinel)."""
|
||||
try:
|
||||
ivalue = int(value)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
|
||||
if ivalue < 0:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} must be a non-negative integer (>= 0)")
|
||||
return ivalue
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
@ -259,7 +270,7 @@ parser.add_argument("--download-segments", type=_positive_int, default=8, metava
|
||||
parser.add_argument("--download-max-active", type=_positive_int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
|
||||
parser.add_argument("--download-max-connections-per-host", type=_positive_int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
|
||||
parser.add_argument("--download-chunk-size", type=_positive_int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
|
||||
parser.add_argument("--download-max-bytes", type=int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
|
||||
parser.add_argument("--download-max-bytes", type=_non_negative_int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
|
||||
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
|
||||
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
|
||||
|
||||
|
||||
@ -8,18 +8,44 @@ session bound to its own event loop.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _drain_scheduler_tasks(scheduler) -> None:
|
||||
"""Cancel and await live scheduler tasks so none outlive the test.
|
||||
|
||||
Uses the actual task handles rather than only clearing ``_tasks``: each
|
||||
per-test event loop is created by ``asyncio.run``, so a task left behind by
|
||||
a crashed/aborted test would otherwise keep its coroutine alive. We cancel
|
||||
every live task and, when its loop is still usable, run it to completion to
|
||||
let the cancellation propagate before dropping the reference.
|
||||
"""
|
||||
for task in list(scheduler._tasks.values()):
|
||||
if task is None:
|
||||
continue
|
||||
loop = task.get_loop()
|
||||
if task.done() or loop.is_closed():
|
||||
continue
|
||||
task.cancel()
|
||||
if not loop.is_running():
|
||||
try:
|
||||
loop.run_until_complete(asyncio.gather(task, return_exceptions=True))
|
||||
except Exception:
|
||||
pass
|
||||
scheduler._tasks.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db():
|
||||
import app.database.db as db
|
||||
from comfy.cli_args import args
|
||||
|
||||
db_path = tempfile.mktemp(suffix="-dlmgr-test.sqlite3")
|
||||
fd, db_path = tempfile.mkstemp(suffix="-dlmgr-test.sqlite3")
|
||||
os.close(fd)
|
||||
args.database_url = f"sqlite:///{db_path}"
|
||||
db.init_db()
|
||||
yield
|
||||
@ -36,11 +62,12 @@ def _reset_runtime():
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
|
||||
ns._session = None
|
||||
_drain_scheduler_tasks(SCHEDULER)
|
||||
SCHEDULER._jobs.clear()
|
||||
SCHEDULER._tasks.clear()
|
||||
SCHEDULER._backoff_until.clear()
|
||||
SCHEDULER._started = False
|
||||
yield
|
||||
_drain_scheduler_tasks(SCHEDULER)
|
||||
ns._session = None
|
||||
|
||||
|
||||
|
||||
@ -69,3 +69,13 @@ def test_structural_skips_unknown_extension(tmp_path):
|
||||
p = tmp_path / "weights.bin"
|
||||
p.write_bytes(b"anything")
|
||||
structural.validate(str(p)) # no structural check, no raise
|
||||
|
||||
|
||||
def test_structural_detects_truncation_via_name_hint(tmp_path):
|
||||
# The downloader validates the opaque temp file (a ``.part`` path) but keys
|
||||
# the format check off the final destination name via ``name_hint``, so
|
||||
# truncation must still be detected instead of silently skipped.
|
||||
p = tmp_path / "bad.comfy-download.part"
|
||||
p.write_bytes(_make_safetensors(256, corrupt_size=True))
|
||||
with pytest.raises(structural.StructuralError):
|
||||
structural.validate(str(p), name_hint="model.safetensors")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user