From a1233b1319f8ca1348e283e7b907bb27de01e41b Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Tue, 17 Mar 2026 19:07:04 -0700 Subject: [PATCH] Fix shared-asset overwrite corruption, stale enrichment race, and path validation - Detach ref to new stub asset on overwrite when siblings share the asset - Add optimistic mtime_ns guard in enrich_asset to discard stale results - Normalize and validate output paths stay under output root, deduplicate - Skip metadata extraction for stub-only registration (align with fast scan) - Add RLock comment explaining re-entrant drain requirement - Log warning when pending enrich drain fails to start - Add create_stub_asset and count_active_siblings query functions Amp-Thread-ID: https://ampcode.com/threads/T-019cfe06-f0dc-776f-81ad-e9f3d71be597 Co-authored-by: Amp --- app/assets/database/queries/__init__.py | 4 + app/assets/database/queries/asset.py | 22 +- .../database/queries/asset_reference.py | 44 +++- app/assets/scanner.py | 25 +- app/assets/seeder.py | 10 +- app/assets/services/ingest.py | 45 ++-- main.py | 245 +++++++++++++----- 7 files changed, 289 insertions(+), 106 deletions(-) diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index 1632937b2..9949e84e1 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -1,6 +1,7 @@ from app.assets.database.queries.asset import ( asset_exists_by_hash, bulk_insert_assets, + create_stub_asset, get_asset_by_hash, get_existing_asset_ids, reassign_asset_references, @@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import ( UnenrichedReferenceRow, bulk_insert_references_ignore_conflicts, bulk_update_enrichment_level, + count_active_siblings, bulk_update_is_missing, bulk_update_needs_verify, convert_metadata_to_rows, @@ -80,6 +82,8 @@ __all__ = [ "bulk_insert_references_ignore_conflicts", "bulk_insert_tags_and_meta", "bulk_update_enrichment_level", + "count_active_siblings", + "create_stub_asset", "bulk_update_is_missing", "bulk_update_needs_verify", "convert_metadata_to_rows", diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py index 594d1f1b2..e57574533 100644 --- a/app/assets/database/queries/asset.py +++ b/app/assets/database/queries/asset.py @@ -4,7 +4,11 @@ from sqlalchemy.dialects import sqlite from sqlalchemy.orm import Session from app.assets.database.models import Asset, AssetReference -from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + calculate_rows_per_statement, + iter_chunks, +) def asset_exists_by_hash( @@ -78,6 +82,18 @@ def upsert_asset( return asset, created, updated +def create_stub_asset( + session: Session, + size_bytes: int, + mime_type: str | None = None, +) -> Asset: + """Create a new asset with no hash (stub for later enrichment).""" + asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None) + session.add(asset) + session.flush() + return asset + + def bulk_insert_assets( session: Session, rows: list[dict], @@ -99,9 +115,7 @@ def get_existing_asset_ids( return set() found: set[str] = set() for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): - rows = session.execute( - select(Asset.id).where(Asset.id.in_(chunk)) - ).fetchall() + rows = session.execute(select(Asset.id).where(Asset.id.in_(chunk))).fetchall() found.update(row[0] for row in rows) return found diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 084a32512..931239523 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -66,14 +66,18 @@ def convert_metadata_to_rows(key: str, value) -> list[dict]: if isinstance(value, list): if all(_check_is_scalar(x) for x in value): - return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None] - return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None] + return [ + _scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None + ] + return [ + {"key": key, "ordinal": i, "val_json": x} + for i, x in enumerate(value) + if x is not None + ] return [{"key": key, "ordinal": 0, "val_json": value}] - - def get_reference_by_id( session: Session, reference_id: str, @@ -114,6 +118,23 @@ def get_reference_by_file_path( ) +def count_active_siblings( + session: Session, + asset_id: str, + exclude_reference_id: str, +) -> int: + """Count active (non-deleted) references to an asset, excluding one reference.""" + return ( + session.query(AssetReference) + .filter( + AssetReference.asset_id == asset_id, + AssetReference.id != exclude_reference_id, + AssetReference.deleted_at.is_(None), + ) + .count() + ) + + def reference_exists_for_asset_id( session: Session, asset_id: str, @@ -643,8 +664,11 @@ def upsert_reference( ) ) .values( - asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, - deleted_at=None, updated_at=now, + asset_id=asset_id, + mtime_ns=int(mtime_ns), + is_missing=False, + deleted_at=None, + updated_at=now, ) ) res2 = session.execute(upd) @@ -834,9 +858,7 @@ def bulk_update_is_missing( return total -def update_is_missing_by_asset_id( - session: Session, asset_id: str, value: bool -) -> int: +def update_is_missing_by_asset_id(session: Session, asset_id: str, value: bool) -> int: """Set is_missing flag for ALL references belonging to an asset. Returns: Number of rows updated @@ -1003,9 +1025,7 @@ def get_references_by_paths_and_asset_ids( pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_( chunk ) - result = session.execute( - select(AssetReference.file_path).where(pairwise) - ) + result = session.execute(select(AssetReference.file_path).where(pairwise)) winners.update(result.scalars().all()) return winners diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 8120e763e..4dc01d5c0 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -13,6 +13,7 @@ from app.assets.database.queries import ( delete_references_by_ids, ensure_tags_exist, get_asset_by_hash, + get_reference_by_id, get_references_for_prefixes, get_unenriched_references, mark_references_missing_outside_prefixes, @@ -346,7 +347,6 @@ def build_asset_specs( return specs, tag_pool, skipped - def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: """Insert asset specs into database, returning count of created refs.""" if not specs: @@ -427,6 +427,7 @@ def enrich_asset( except OSError: return new_level + initial_mtime_ns = get_mtime_ns(stat_p) rel_fname = compute_relative_filename(file_path) mime_type: str | None = None metadata = None @@ -453,8 +454,10 @@ def enrich_asset( checkpoint = hash_checkpoints.get(file_path) if checkpoint is not None: cur_stat = os.stat(file_path, follow_symlinks=True) - if (checkpoint.mtime_ns != get_mtime_ns(cur_stat) - or checkpoint.file_size != cur_stat.st_size): + if ( + checkpoint.mtime_ns != get_mtime_ns(cur_stat) + or checkpoint.file_size != cur_stat.st_size + ): checkpoint = None hash_checkpoints.pop(file_path, None) else: @@ -481,7 +484,9 @@ def enrich_asset( stat_after = os.stat(file_path, follow_symlinks=True) mtime_after = get_mtime_ns(stat_after) if mtime_before != mtime_after: - logging.warning("File modified during hashing, discarding hash: %s", file_path) + logging.warning( + "File modified during hashing, discarding hash: %s", file_path + ) else: full_hash = f"blake3:{digest}" metadata_ok = not extract_metadata or metadata is not None @@ -490,6 +495,18 @@ def enrich_asset( except Exception as e: logging.warning("Failed to hash %s: %s", file_path, e) + # Optimistic guard: if the reference's mtime_ns changed since we + # started (e.g. ingest_existing_file updated it), our results are + # stale — discard them to avoid overwriting fresh registration data. + ref = get_reference_by_id(session, reference_id) + if ref is None or ref.mtime_ns != initial_mtime_ns: + session.rollback() + logging.info( + "Ref %s mtime changed during enrichment, discarding stale result", + reference_id, + ) + return ENRICHMENT_STUB + if extract_metadata and metadata: system_metadata = metadata.to_user_metadata() set_reference_system_metadata(session, reference_id, system_metadata) diff --git a/app/assets/seeder.py b/app/assets/seeder.py index d65ffb23f..1ae2d3149 100644 --- a/app/assets/seeder.py +++ b/app/assets/seeder.py @@ -77,6 +77,8 @@ class _AssetSeeder: """ def __init__(self) -> None: + # RLock is required because _run_scan() drains pending work while + # holding _lock and re-enters start() which also acquires _lock. self._lock = threading.RLock() self._state = State.IDLE self._progress: Progress | None = None @@ -639,10 +641,14 @@ class _AssetSeeder: pending = self._pending_enrich if pending is not None: self._pending_enrich = None - self.start_enrich( + if not self.start_enrich( roots=pending["roots"], compute_hashes=pending["compute_hashes"], - ) + ): + logging.warning( + "Pending enrich scan could not start (roots=%s)", + pending["roots"], + ) def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: """Run phase 1: fast scan to create stub records. diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index dc53276dc..7899b68cf 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -9,6 +9,8 @@ from sqlalchemy.orm import Session import app.assets.services.hashing as hashing from app.assets.database.queries import ( add_tags_to_reference, + count_active_siblings, + create_stub_asset, fetch_reference_and_asset, get_asset_by_hash, get_reference_by_file_path, @@ -26,7 +28,6 @@ from app.assets.database.queries import ( from app.assets.helpers import get_utc_now, normalize_tags from app.assets.services.bulk_ingest import batch_insert_seed_assets from app.assets.services.file_utils import get_size_and_mtime_ns -from app.assets.services.metadata_extract import extract_file_metadata from app.assets.services.path_utils import ( compute_relative_filename, get_name_and_tags_from_asset_path, @@ -146,7 +147,9 @@ def register_output_files( if not os.path.isfile(abs_path): continue try: - if ingest_existing_file(abs_path, user_metadata=user_metadata, job_id=job_id): + if ingest_existing_file( + abs_path, user_metadata=user_metadata, job_id=job_id + ): registered += 1 except Exception: logging.exception("Failed to register output: %s", abs_path) @@ -185,19 +188,28 @@ def ingest_existing_file( existing_ref.is_missing = False existing_ref.deleted_at = None existing_ref.updated_at = now - # Reset enrichment so the enricher re-hashes existing_ref.enrichment_level = 0 - # Clear the asset hash so enrich recomputes it + asset = existing_ref.asset if asset: - asset.hash = None - asset.size_bytes = size_bytes - if mime_type: - asset.mime_type = mime_type + # If other refs share this asset, detach to a new stub + # instead of mutating the shared row. + siblings = count_active_siblings(session, asset.id, existing_ref.id) + if siblings > 0: + new_asset = create_stub_asset( + session, + size_bytes=size_bytes, + mime_type=mime_type or asset.mime_type, + ) + existing_ref.asset_id = new_asset.id + else: + asset.hash = None + asset.size_bytes = size_bytes + if mime_type: + asset.mime_type = mime_type session.commit() return True - metadata = extract_file_metadata(locator) spec = { "abs_path": abs_path, "size_bytes": size_bytes, @@ -205,9 +217,9 @@ def ingest_existing_file( "info_name": name, "tags": tags, "fname": os.path.basename(abs_path), - "metadata": metadata, + "metadata": None, "hash": None, - "mime_type": mime_type or metadata.content_type, + "mime_type": mime_type, "job_id": job_id, } result = batch_insert_seed_assets(session, [spec], owner_id=owner_id) @@ -262,7 +274,9 @@ def _register_existing_asset( return result new_meta = dict(user_metadata) - computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + computed_filename = ( + compute_relative_filename(ref.file_path) if ref.file_path else None + ) if computed_filename: new_meta["filename"] = computed_filename @@ -294,7 +308,6 @@ def _register_existing_asset( return result - def _update_metadata_with_filename( session: Session, reference_id: str, @@ -475,8 +488,7 @@ def register_file_in_place( size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path) content_type = mime_type or ( - mimetypes.guess_type(abs_path, strict=False)[0] - or "application/octet-stream" + mimetypes.guess_type(abs_path, strict=False)[0] or "application/octet-stream" ) ingest_result = _ingest_file_from_path( @@ -527,7 +539,8 @@ def create_from_hash( result = _register_existing_asset( asset_hash=canonical, name=_sanitize_filename( - name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + name, + fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical, ), user_metadata=user_metadata or {}, tags=tags or [], diff --git a/main.py b/main.py index da4983cd4..622ea1be3 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import comfy.options + comfy.options.enable_args_parsing() import os @@ -23,9 +24,9 @@ from comfy_api import feature_flags from app.database.db import init_db, dependencies_available if __name__ == "__main__": - #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. - os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' - os.environ['DO_NOT_TRACK'] = '1' + # NOTE: These do not do anything on core ComfyUI, they are for custom nodes. + os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" + os.environ["DO_NOT_TRACK"] = "1" setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) @@ -37,40 +38,46 @@ if enables_dynamic_vram(): comfy_aimdo.control.init() if os.name == "nt": - os.environ['MIMALLOC_PURGE_DELAY'] = '0' + os.environ["MIMALLOC_PURGE_DELAY"] = "0" if __name__ == "__main__": - os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' + os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1" if args.default_device is not None: default_dev = args.default_device devices = list(range(32)) devices.remove(default_dev) devices.insert(0, default_dev) - devices = ','.join(map(str, devices)) - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) - os.environ['HIP_VISIBLE_DEVICES'] = str(devices) + devices = ",".join(map(str, devices)) + os.environ["CUDA_VISIBLE_DEVICES"] = str(devices) + os.environ["HIP_VISIBLE_DEVICES"] = str(devices) if args.cuda_device is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device) + os.environ["HIP_VISIBLE_DEVICES"] = str(args.cuda_device) os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.oneapi_device_selector is not None: - os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector - logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) + os.environ["ONEAPI_DEVICE_SELECTOR"] = args.oneapi_device_selector + logging.info( + "Set oneapi device selector to: {}".format(args.oneapi_device_selector) + ) if args.deterministic: - if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + if "CUBLAS_WORKSPACE_CONFIG" not in os.environ: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): - os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + os.environ["OCL_SET_SVM_SIZE"] = "262144" # set at the request of AMD def handle_comfyui_manager_unavailable(): - manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") + manager_req_path = os.path.join( + os.path.dirname(os.path.abspath(folder_paths.__file__)), + "manager_requirements.txt", + ) uv_available = shutil.which("uv") is not None pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" @@ -86,7 +93,9 @@ if args.enable_manager: if importlib.util.find_spec("comfyui_manager"): import comfyui_manager - if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'): + if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith( + "__init__.py" + ): handle_comfyui_manager_unavailable() else: handle_comfyui_manager_unavailable() @@ -94,7 +103,9 @@ if args.enable_manager: def apply_custom_paths(): # extra model paths - extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") + extra_model_paths_config_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml" + ) if os.path.isfile(extra_model_paths_config_path): utils.extra_config.load_extra_path_config(extra_model_paths_config_path) @@ -109,12 +120,22 @@ def apply_custom_paths(): folder_paths.set_output_directory(output_dir) # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes - folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) - folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) - folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) - folder_paths.add_model_folder_path("diffusion_models", - os.path.join(folder_paths.get_output_directory(), "diffusion_models")) - folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras")) + folder_paths.add_model_folder_path( + "checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints") + ) + folder_paths.add_model_folder_path( + "clip", os.path.join(folder_paths.get_output_directory(), "clip") + ) + folder_paths.add_model_folder_path( + "vae", os.path.join(folder_paths.get_output_directory(), "vae") + ) + folder_paths.add_model_folder_path( + "diffusion_models", + os.path.join(folder_paths.get_output_directory(), "diffusion_models"), + ) + folder_paths.add_model_folder_path( + "loras", os.path.join(folder_paths.get_output_directory(), "loras") + ) if args.input_directory: input_dir = os.path.abspath(args.input_directory) @@ -154,17 +175,28 @@ def execute_prestartup_script(): if comfyui_manager.should_be_disabled(module_path): continue - if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": + if ( + os.path.isfile(module_path) + or module_path.endswith(".disabled") + or module_path == "__pycache__" + ): continue script_path = os.path.join(module_path, "prestartup_script.py") if os.path.exists(script_path): - if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: - logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") + if ( + args.disable_all_custom_nodes + and possible_module not in args.whitelist_custom_nodes + ): + logging.info( + f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes" + ) continue time_before = time.perf_counter() success = execute_script(script_path) - node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) + node_prestartup_times.append( + (time.perf_counter() - time_before, module_path, success) + ) if len(node_prestartup_times) > 0: logging.info("\nPrestartup times for custom nodes:") for n in sorted(node_prestartup_times): @@ -175,6 +207,7 @@ def execute_prestartup_script(): logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) logging.info("") + apply_custom_paths() init_mime_types() @@ -189,8 +222,10 @@ import asyncio import threading import gc -if 'torch' in sys.modules: - logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") +if "torch" in sys.modules: + logging.warning( + "WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point." + ) import comfy.utils @@ -207,26 +242,38 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): - if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): - logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): - if args.verbose == 'DEBUG': +if args.enable_dynamic_vram or ( + enables_dynamic_vram() + and comfy.model_management.is_nvidia() + and not comfy.model_management.is_wsl() +): + if (not args.enable_dynamic_vram) and ( + comfy.model_management.torch_version_numeric < (2, 8) + ): + logging.warning( + "Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows" + ) + elif comfy_aimdo.control.init_device( + comfy.model_management.get_torch_device().index + ): + if args.verbose == "DEBUG": comfy_aimdo.control.set_log_debug() - elif args.verbose == 'CRITICAL': + elif args.verbose == "CRITICAL": comfy_aimdo.control.set_log_critical() - elif args.verbose == 'ERROR': + elif args.verbose == "ERROR": comfy_aimdo.control.set_log_error() - elif args.verbose == 'WARNING': + elif args.verbose == "WARNING": comfy_aimdo.control.set_log_warning() - else: #INFO + else: # INFO comfy_aimdo.control.set_log_info() comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic comfy.memory_management.aimdo_enabled = True logging.info("DynamicVRAM support detected and enabled") else: - logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + logging.warning( + "No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows" + ) def cuda_malloc_warning(): @@ -238,15 +285,19 @@ def cuda_malloc_warning(): if b in device_name: cuda_malloc_warning = True if cuda_malloc_warning: - logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") + logging.warning( + '\nWARNING: this card most likely does not support cuda-malloc, if you get "CUDA error" please run ComfyUI with: --disable-cuda-malloc\n' + ) def _collect_output_absolute_paths(history_result: dict) -> list[str]: """Extract absolute file paths for output items from a history result.""" - paths = [] + paths: list[str] = [] base_dir = folder_paths.get_directory_by_type("output") if base_dir is None: return paths + base_dir = os.path.abspath(base_dir) + seen: set[str] = set() for node_output in history_result.get("outputs", {}).values(): for items in node_output.values(): if not isinstance(items, list): @@ -257,7 +308,14 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]: filename = item.get("filename") if not filename: continue - paths.append(os.path.join(base_dir, item.get("subfolder", ""), filename)) + abs_path = os.path.abspath( + os.path.join(base_dir, item.get("subfolder", ""), filename) + ) + if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir: + continue + if abs_path not in seen: + seen.add(abs_path) + paths.append(abs_path) return paths @@ -271,7 +329,11 @@ def prompt_worker(q, server_instance): elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) + e = execution.PromptExecutor( + server_instance, + cache_type=cache_type, + cache_args={"lru": args.cache_lru, "ram": args.cache_ram}, + ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -299,14 +361,22 @@ def prompt_worker(q, server_instance): need_gc = True remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] - q.task_done(item_id, - e.history_result, - status=execution.PromptQueue.ExecutionStatus( - status_str='success' if e.success else 'error', - completed=e.success, - messages=e.status_messages), process_item=remove_sensitive) + q.task_done( + item_id, + e.history_result, + status=execution.PromptQueue.ExecutionStatus( + status_str="success" if e.success else "error", + completed=e.success, + messages=e.status_messages, + ), + process_item=remove_sensitive, + ) if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + server_instance.send_sync( + "executing", + {"node": None, "prompt_id": prompt_id}, + server_instance.client_id, + ) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -349,14 +419,16 @@ def prompt_worker(q, server_instance): asset_seeder.resume() -async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): +async def run(server_instance, address="", port=8188, verbose=True, call_on_start=None): addresses = [] for addr in address.split(","): addresses.append((addr, port)) await asyncio.gather( - server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop() + server_instance.start_multi_address(addresses, call_on_start, verbose), + server_instance.publish_loop(), ) + def hijack_progress(server_instance): def hook(value, total, preview_image, prompt_id=None, node_id=None): executing_context = get_executing_context() @@ -369,7 +441,12 @@ def hijack_progress(server_instance): prompt_id = server_instance.last_prompt_id if node_id is None: node_id = server_instance.last_node_id - progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} + progress = { + "value": value, + "max": total, + "prompt_id": prompt_id, + "node": node_id, + } get_progress_state().update_progress(node_id, value, total, preview_image) server_instance.send_sync("progress", progress, server_instance.client_id) @@ -400,8 +477,14 @@ def setup_database(): if dependencies_available(): init_db() if args.enable_assets: - if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True): - logging.info("Background asset scan initiated for models, input, output") + if asset_seeder.start( + roots=("models", "input", "output"), + prune_first=True, + compute_hashes=True, + ): + logging.info( + "Background asset scan initiated for models, input, output" + ) except Exception as e: if "database is locked" in str(e): logging.error( @@ -420,7 +503,9 @@ def setup_database(): " 3. Use an in-memory database: --database-url sqlite:///:memory:" ) sys.exit(1) - logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") + logging.error( + f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}" + ) def start_comfyui(asyncio_loop=None): @@ -437,6 +522,7 @@ def start_comfyui(asyncio_loop=None): if args.windows_standalone_build: try: import new_updater + new_updater.update_windows_updater() except: pass @@ -450,10 +536,13 @@ def start_comfyui(asyncio_loop=None): comfyui_manager.start() hook_breaker_ac10a0.save_functions() - asyncio_loop.run_until_complete(nodes.init_extra_nodes( - init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, - init_api_nodes=not args.disable_api_nodes - )) + asyncio_loop.run_until_complete( + nodes.init_extra_nodes( + init_custom_nodes=(not args.disable_all_custom_nodes) + or len(args.whitelist_custom_nodes) > 0, + init_api_nodes=not args.disable_api_nodes, + ) + ) hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() @@ -462,7 +551,14 @@ def start_comfyui(asyncio_loop=None): prompt_server.add_routes() hijack_progress(prompt_server) - threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start() + threading.Thread( + target=prompt_worker, + daemon=True, + args=( + prompt_server.prompt_queue, + prompt_server, + ), + ).start() if args.quick_test_for_ci: exit(0) @@ -470,18 +566,27 @@ def start_comfyui(asyncio_loop=None): os.makedirs(folder_paths.get_temp_directory(), exist_ok=True) call_on_start = None if args.auto_launch: + def startup_server(scheme, address, port): import webbrowser - if os.name == 'nt' and address == '0.0.0.0': - address = '127.0.0.1' - if ':' in address: + + if os.name == "nt" and address == "0.0.0.0": + address = "127.0.0.1" + if ":" in address: address = "[{}]".format(address) webbrowser.open(f"{scheme}://{address}:{port}") + call_on_start = startup_server async def start_all(): await prompt_server.setup() - await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start) + await run( + prompt_server, + address=args.listen, + port=args.port, + verbose=not args.dont_print_server, + call_on_start=call_on_start, + ) # Returning these so that other code can integrate with the ComfyUI loop and server return asyncio_loop, prompt_server, start_all @@ -493,12 +598,16 @@ if __name__ == "__main__": logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) for package in ("comfy-aimdo", "comfy-kitchen"): try: - logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + logging.info( + "{} version: {}".format(package, importlib.metadata.version(package)) + ) except: pass if sys.version_info.major == 3 and sys.version_info.minor < 10: - logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") + logging.warning( + "WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended." + ) event_loop, _, start_all_func = start_comfyui() try: