diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py index e57574533..cc7168431 100644 --- a/app/assets/database/queries/asset.py +++ b/app/assets/database/queries/asset.py @@ -4,11 +4,7 @@ from sqlalchemy.dialects import sqlite from sqlalchemy.orm import Session from app.assets.database.models import Asset, AssetReference -from app.assets.database.queries.common import ( - MAX_BIND_PARAMS, - calculate_rows_per_statement, - iter_chunks, -) +from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks def asset_exists_by_hash( @@ -115,7 +111,9 @@ def get_existing_asset_ids( return set() found: set[str] = set() for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): - rows = session.execute(select(Asset.id).where(Asset.id.in_(chunk))).fetchall() + rows = session.execute( + select(Asset.id).where(Asset.id.in_(chunk)) + ).fetchall() found.update(row[0] for row in rows) return found diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 931239523..8b90ae511 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -66,18 +66,14 @@ def convert_metadata_to_rows(key: str, value) -> list[dict]: if isinstance(value, list): if all(_check_is_scalar(x) for x in value): - return [ - _scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None - ] - return [ - {"key": key, "ordinal": i, "val_json": x} - for i, x in enumerate(value) - if x is not None - ] + return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None] + return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None] return [{"key": key, "ordinal": 0, "val_json": value}] + + def get_reference_by_id( session: Session, reference_id: str, @@ -664,11 +660,8 @@ def upsert_reference( ) ) .values( - asset_id=asset_id, - mtime_ns=int(mtime_ns), - is_missing=False, - deleted_at=None, - updated_at=now, + asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, + deleted_at=None, updated_at=now, ) ) res2 = session.execute(upd) @@ -858,7 +851,9 @@ def bulk_update_is_missing( return total -def update_is_missing_by_asset_id(session: Session, asset_id: str, value: bool) -> int: +def update_is_missing_by_asset_id( + session: Session, asset_id: str, value: bool +) -> int: """Set is_missing flag for ALL references belonging to an asset. Returns: Number of rows updated @@ -1025,7 +1020,9 @@ def get_references_by_paths_and_asset_ids( pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_( chunk ) - result = session.execute(select(AssetReference.file_path).where(pairwise)) + result = session.execute( + select(AssetReference.file_path).where(pairwise) + ) winners.update(result.scalars().all()) return winners diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 4dc01d5c0..ebb6869af 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -347,6 +347,7 @@ def build_asset_specs( return specs, tag_pool, skipped + def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: """Insert asset specs into database, returning count of created refs.""" if not specs: @@ -454,10 +455,8 @@ def enrich_asset( checkpoint = hash_checkpoints.get(file_path) if checkpoint is not None: cur_stat = os.stat(file_path, follow_symlinks=True) - if ( - checkpoint.mtime_ns != get_mtime_ns(cur_stat) - or checkpoint.file_size != cur_stat.st_size - ): + if (checkpoint.mtime_ns != get_mtime_ns(cur_stat) + or checkpoint.file_size != cur_stat.st_size): checkpoint = None hash_checkpoints.pop(file_path, None) else: @@ -484,9 +483,7 @@ def enrich_asset( stat_after = os.stat(file_path, follow_symlinks=True) mtime_after = get_mtime_ns(stat_after) if mtime_before != mtime_after: - logging.warning( - "File modified during hashing, discarding hash: %s", file_path - ) + logging.warning("File modified during hashing, discarding hash: %s", file_path) else: full_hash = f"blake3:{digest}" metadata_ok = not extract_metadata or metadata is not None diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index fe97691bb..f0b070517 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -277,9 +277,7 @@ def _register_existing_asset( return result new_meta = dict(user_metadata) - computed_filename = ( - compute_relative_filename(ref.file_path) if ref.file_path else None - ) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None if computed_filename: new_meta["filename"] = computed_filename @@ -311,6 +309,7 @@ def _register_existing_asset( return result + def _update_metadata_with_filename( session: Session, reference_id: str, @@ -491,7 +490,8 @@ def register_file_in_place( size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path) content_type = mime_type or ( - mimetypes.guess_type(abs_path, strict=False)[0] or "application/octet-stream" + mimetypes.guess_type(abs_path, strict=False)[0] + or "application/octet-stream" ) ingest_result = _ingest_file_from_path( @@ -542,8 +542,7 @@ def create_from_hash( result = _register_existing_asset( asset_hash=canonical, name=_sanitize_filename( - name, - fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical, + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical ), user_metadata=user_metadata or {}, tags=tags or [], diff --git a/main.py b/main.py index b353c8f5b..f5a0639e8 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,4 @@ import comfy.options - comfy.options.enable_args_parsing() import os @@ -24,9 +23,9 @@ from comfy_api import feature_flags from app.database.db import init_db, dependencies_available if __name__ == "__main__": - # NOTE: These do not do anything on core ComfyUI, they are for custom nodes. - os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" - os.environ["DO_NOT_TRACK"] = "1" + #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. + os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' + os.environ['DO_NOT_TRACK'] = '1' setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) @@ -38,46 +37,40 @@ if enables_dynamic_vram(): comfy_aimdo.control.init() if os.name == "nt": - os.environ["MIMALLOC_PURGE_DELAY"] = "0" + os.environ['MIMALLOC_PURGE_DELAY'] = '0' if __name__ == "__main__": - os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1" + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' if args.default_device is not None: default_dev = args.default_device devices = list(range(32)) devices.remove(default_dev) devices.insert(0, default_dev) - devices = ",".join(map(str, devices)) - os.environ["CUDA_VISIBLE_DEVICES"] = str(devices) - os.environ["HIP_VISIBLE_DEVICES"] = str(devices) + devices = ','.join(map(str, devices)) + os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) + os.environ['HIP_VISIBLE_DEVICES'] = str(devices) if args.cuda_device is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device) - os.environ["HIP_VISIBLE_DEVICES"] = str(args.cuda_device) + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.oneapi_device_selector is not None: - os.environ["ONEAPI_DEVICE_SELECTOR"] = args.oneapi_device_selector - logging.info( - "Set oneapi device selector to: {}".format(args.oneapi_device_selector) - ) + os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector + logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) if args.deterministic: - if "CUBLAS_WORKSPACE_CONFIG" not in os.environ: - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" import cuda_malloc - if "rocm" in cuda_malloc.get_torch_version_noimport(): - os.environ["OCL_SET_SVM_SIZE"] = "262144" # set at the request of AMD + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD def handle_comfyui_manager_unavailable(): - manager_req_path = os.path.join( - os.path.dirname(os.path.abspath(folder_paths.__file__)), - "manager_requirements.txt", - ) + manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") uv_available = shutil.which("uv") is not None pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" @@ -93,9 +86,7 @@ if args.enable_manager: if importlib.util.find_spec("comfyui_manager"): import comfyui_manager - if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith( - "__init__.py" - ): + if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'): handle_comfyui_manager_unavailable() else: handle_comfyui_manager_unavailable() @@ -103,9 +94,7 @@ if args.enable_manager: def apply_custom_paths(): # extra model paths - extra_model_paths_config_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml" - ) + extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): utils.extra_config.load_extra_path_config(extra_model_paths_config_path) @@ -120,22 +109,12 @@ def apply_custom_paths(): folder_paths.set_output_directory(output_dir) # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes - folder_paths.add_model_folder_path( - "checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints") - ) - folder_paths.add_model_folder_path( - "clip", os.path.join(folder_paths.get_output_directory(), "clip") - ) - folder_paths.add_model_folder_path( - "vae", os.path.join(folder_paths.get_output_directory(), "vae") - ) - folder_paths.add_model_folder_path( - "diffusion_models", - os.path.join(folder_paths.get_output_directory(), "diffusion_models"), - ) - folder_paths.add_model_folder_path( - "loras", os.path.join(folder_paths.get_output_directory(), "loras") - ) + folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) + folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) + folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) + folder_paths.add_model_folder_path("diffusion_models", + os.path.join(folder_paths.get_output_directory(), "diffusion_models")) + folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras")) if args.input_directory: input_dir = os.path.abspath(args.input_directory) @@ -175,28 +154,17 @@ def execute_prestartup_script(): if comfyui_manager.should_be_disabled(module_path): continue - if ( - os.path.isfile(module_path) - or module_path.endswith(".disabled") - or module_path == "__pycache__" - ): + if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": continue script_path = os.path.join(module_path, "prestartup_script.py") if os.path.exists(script_path): - if ( - args.disable_all_custom_nodes - and possible_module not in args.whitelist_custom_nodes - ): - logging.info( - f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes" - ) + if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: + logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue time_before = time.perf_counter() success = execute_script(script_path) - node_prestartup_times.append( - (time.perf_counter() - time_before, module_path, success) - ) + node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_prestartup_times) > 0: logging.info("\nPrestartup times for custom nodes:") for n in sorted(node_prestartup_times): @@ -207,7 +175,6 @@ def execute_prestartup_script(): logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) logging.info("") - apply_custom_paths() init_mime_types() @@ -222,10 +189,8 @@ import asyncio import threading import gc -if "torch" in sys.modules: - logging.warning( - "WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point." - ) +if 'torch' in sys.modules: + logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") import comfy.utils @@ -242,38 +207,26 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if args.enable_dynamic_vram or ( - enables_dynamic_vram() - and comfy.model_management.is_nvidia() - and not comfy.model_management.is_wsl() -): - if (not args.enable_dynamic_vram) and ( - comfy.model_management.torch_version_numeric < (2, 8) - ): - logging.warning( - "Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows" - ) - elif comfy_aimdo.control.init_device( - comfy.model_management.get_torch_device().index - ): - if args.verbose == "DEBUG": +if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): + if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): + logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() - elif args.verbose == "CRITICAL": + elif args.verbose == 'CRITICAL': comfy_aimdo.control.set_log_critical() - elif args.verbose == "ERROR": + elif args.verbose == 'ERROR': comfy_aimdo.control.set_log_error() - elif args.verbose == "WARNING": + elif args.verbose == 'WARNING': comfy_aimdo.control.set_log_warning() - else: # INFO + else: #INFO comfy_aimdo.control.set_log_info() comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic comfy.memory_management.aimdo_enabled = True logging.info("DynamicVRAM support detected and enabled") else: - logging.warning( - "No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows" - ) + logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") def cuda_malloc_warning(): @@ -285,9 +238,7 @@ def cuda_malloc_warning(): if b in device_name: cuda_malloc_warning = True if cuda_malloc_warning: - logging.warning( - '\nWARNING: this card most likely does not support cuda-malloc, if you get "CUDA error" please run ComfyUI with: --disable-cuda-malloc\n' - ) + logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") def _collect_output_absolute_paths(history_result: dict) -> list[str]: @@ -329,11 +280,7 @@ def prompt_worker(q, server_instance): elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor( - server_instance, - cache_type=cache_type, - cache_args={"lru": args.cache_lru, "ram": args.cache_ram}, - ) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -361,22 +308,14 @@ def prompt_worker(q, server_instance): need_gc = True remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] - q.task_done( - item_id, - e.history_result, - status=execution.PromptQueue.ExecutionStatus( - status_str="success" if e.success else "error", - completed=e.success, - messages=e.status_messages, - ), - process_item=remove_sensitive, - ) + q.task_done(item_id, + e.history_result, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages), process_item=remove_sensitive) if server_instance.client_id is not None: - server_instance.send_sync( - "executing", - {"node": None, "prompt_id": prompt_id}, - server_instance.client_id, - ) + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -419,16 +358,14 @@ def prompt_worker(q, server_instance): asset_seeder.resume() -async def run(server_instance, address="", port=8188, verbose=True, call_on_start=None): +async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): addresses = [] for addr in address.split(","): addresses.append((addr, port)) await asyncio.gather( - server_instance.start_multi_address(addresses, call_on_start, verbose), - server_instance.publish_loop(), + server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop() ) - def hijack_progress(server_instance): def hook(value, total, preview_image, prompt_id=None, node_id=None): executing_context = get_executing_context() @@ -441,12 +378,7 @@ def hijack_progress(server_instance): prompt_id = server_instance.last_prompt_id if node_id is None: node_id = server_instance.last_node_id - progress = { - "value": value, - "max": total, - "prompt_id": prompt_id, - "node": node_id, - } + progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} get_progress_state().update_progress(node_id, value, total, preview_image) server_instance.send_sync("progress", progress, server_instance.client_id) @@ -477,14 +409,8 @@ def setup_database(): if dependencies_available(): init_db() if args.enable_assets: - if asset_seeder.start( - roots=("models", "input", "output"), - prune_first=True, - compute_hashes=True, - ): - logging.info( - "Background asset scan initiated for models, input, output" - ) + if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True): + logging.info("Background asset scan initiated for models, input, output") except Exception as e: if "database is locked" in str(e): logging.error( @@ -503,9 +429,7 @@ def setup_database(): " 3. Use an in-memory database: --database-url sqlite:///:memory:" ) sys.exit(1) - logging.error( - f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}" - ) + logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") def start_comfyui(asyncio_loop=None): @@ -522,7 +446,6 @@ def start_comfyui(asyncio_loop=None): if args.windows_standalone_build: try: import new_updater - new_updater.update_windows_updater() except: pass @@ -536,13 +459,10 @@ def start_comfyui(asyncio_loop=None): comfyui_manager.start() hook_breaker_ac10a0.save_functions() - asyncio_loop.run_until_complete( - nodes.init_extra_nodes( - init_custom_nodes=(not args.disable_all_custom_nodes) - or len(args.whitelist_custom_nodes) > 0, - init_api_nodes=not args.disable_api_nodes, - ) - ) + asyncio_loop.run_until_complete(nodes.init_extra_nodes( + init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, + init_api_nodes=not args.disable_api_nodes + )) hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() @@ -551,14 +471,7 @@ def start_comfyui(asyncio_loop=None): prompt_server.add_routes() hijack_progress(prompt_server) - threading.Thread( - target=prompt_worker, - daemon=True, - args=( - prompt_server.prompt_queue, - prompt_server, - ), - ).start() + threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start() if args.quick_test_for_ci: exit(0) @@ -566,27 +479,18 @@ def start_comfyui(asyncio_loop=None): os.makedirs(folder_paths.get_temp_directory(), exist_ok=True) call_on_start = None if args.auto_launch: - def startup_server(scheme, address, port): import webbrowser - - if os.name == "nt" and address == "0.0.0.0": - address = "127.0.0.1" - if ":" in address: + if os.name == 'nt' and address == '0.0.0.0': + address = '127.0.0.1' + if ':' in address: address = "[{}]".format(address) webbrowser.open(f"{scheme}://{address}:{port}") - call_on_start = startup_server async def start_all(): await prompt_server.setup() - await run( - prompt_server, - address=args.listen, - port=args.port, - verbose=not args.dont_print_server, - call_on_start=call_on_start, - ) + await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start) # Returning these so that other code can integrate with the ComfyUI loop and server return asyncio_loop, prompt_server, start_all @@ -598,16 +502,12 @@ if __name__ == "__main__": logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) for package in ("comfy-aimdo", "comfy-kitchen"): try: - logging.info( - "{} version: {}".format(package, importlib.metadata.version(package)) - ) + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) except: pass if sys.version_info.major == 3 and sys.version_info.minor < 10: - logging.warning( - "WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended." - ) + logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") if args.disable_dynamic_vram: logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")