diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..e174c6e6b --- /dev/null +++ b/.editorconfig @@ -0,0 +1,4 @@ +root = true + +[*] +max_line_length = off \ No newline at end of file diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 73f8dc42f..95b58bc7c 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -11,7 +11,9 @@ import typing from typing import List, Optional, Tuple, Union from typing_extensions import TypedDict import torch +import lazy_object_proxy +from .. import interruption from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..component_model.executor_types import ExecutorToClientProgress @@ -21,7 +23,7 @@ from ..nodes.package import import_all_nodes_in_workspace # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the # various functions that are declared here. It should have been a context in the first place. -nodes: ExportedNodes = import_all_nodes_in_workspace() +nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None): if extra_data is None: @@ -80,16 +82,16 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): results = [] if input_is_list: if allow_interrupt: - model_management.throw_exception_if_processing_interrupted() + interruption.throw_exception_if_processing_interrupted() results.append(getattr(obj, func)(**input_data_all)) elif max_len_input == 0: if allow_interrupt: - model_management.throw_exception_if_processing_interrupted() + interruption.throw_exception_if_processing_interrupted() results.append(getattr(obj, func)()) else: for i in range(max_len_input): if allow_interrupt: - model_management.throw_exception_if_processing_interrupted() + interruption.throw_exception_if_processing_interrupted() results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) return results @@ -185,7 +187,7 @@ def recursive_execute(server: ExecutorToClientProgress, if server.client_id is not None: server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id) - except model_management.InterruptProcessingException as iex: + except interruption.InterruptProcessingException as iex: logging.info("Processing interrupted") # skip formatting inputs/outputs @@ -327,7 +329,7 @@ class PromptExecutor: # First, send back the status to the frontend depending # on the exception type - if isinstance(ex, model_management.InterruptProcessingException): + if isinstance(ex, interruption.InterruptProcessingException): mes = { "prompt_id": prompt_id, "node_id": node_id, @@ -367,7 +369,7 @@ class PromptExecutor: execute_outputs = [] if extra_data is None: extra_data = {} - model_management.interrupt_current_processing(False) + interruption.interrupt_current_processing(False) if "client_id" in extra_data: self.server.client_id = extra_data["client_id"] diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 8c26b228c..c6a7fcd16 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -1,17 +1,79 @@ +import dataclasses import os +import posixpath import sys import time import logging -from typing import Optional +from typing import Optional, List, Set, Dict, Any, Iterator, Sequence from pkg_resources import resource_filename from ..cli_args import args -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) +supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) -folder_names_and_paths = {} +@dataclasses.dataclass +class FolderPathsTuple: + folder_name: str + paths: List[str] = dataclasses.field(default_factory=list) + supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions)) + + def __getitem__(self, item: Any): + if item == 0: + return self.paths + if item == 1: + return self.supported_extensions + else: + raise RuntimeError("unsupported tuple index") + + def __add__(self, other: "FolderPathsTuple"): + assert self.folder_name == other.folder_name + new_paths = list(frozenset(self.paths + other.paths)) + new_supported_extensions = self.supported_extensions | other.supported_extensions + return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions) + + def __iter__(self) -> Iterator[Sequence[str]]: + yield self.paths + yield self.supported_extensions + + +class FolderNames: + def __init__(self): + self.contents: Dict[str, FolderPathsTuple] = dict() + + def __getitem__(self, item) -> FolderPathsTuple: + if not isinstance(item, str): + raise RuntimeError("expected folder path") + if item not in self.contents: + self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set()) + return self.contents[item] + + def __setitem__(self, key: str, value: FolderPathsTuple): + assert isinstance(key, str) + if isinstance(value, tuple): + paths, supported_extensions = value + value = FolderPathsTuple(key, paths, supported_extensions) + if key in self.contents: + value = self.contents[key] + value + self.contents[key] = value + + def __len__(self): + return len(self.contents) + + def items(self): + return self.contents.items() + + def values(self): + return self.contents.values() + + def keys(self): + return self.contents.keys() + + +folder_names_and_paths = FolderNames() + +# todo: this should be initialized elsewhere if 'main.py' in sys.argv: base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) elif args.cwd is not None: @@ -25,31 +87,31 @@ elif args.cwd is not None: else: base_path = os.getcwd() models_dir = os.path.join(base_path, "models") -folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions) -folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], set([".yaml"])) -folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) -folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) -folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) -folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions) -folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) -folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) -folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) -folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) -folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions) -folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) -folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) -folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) -folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) -folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) -folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) -folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) +folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions)) +folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"}) +folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions)) +folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions)) +folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions)) +folder_names_and_paths["unet"] = FolderPathsTuple("unet", [os.path.join(models_dir, "unet")], set(supported_pt_extensions)) +folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions)) +folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions)) +folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions)) +folder_names_and_paths["diffusers"] = FolderPathsTuple("diffusers", [os.path.join(models_dir, "diffusers")], {"folder"}) +folder_names_and_paths["vae_approx"] = FolderPathsTuple("vae_approx", [os.path.join(models_dir, "vae_approx")], set(supported_pt_extensions)) +folder_names_and_paths["controlnet"] = FolderPathsTuple("controlnet", [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], set(supported_pt_extensions)) +folder_names_and_paths["gligen"] = FolderPathsTuple("gligen", [os.path.join(models_dir, "gligen")], set(supported_pt_extensions)) +folder_names_and_paths["upscale_models"] = FolderPathsTuple("upscale_models", [os.path.join(models_dir, "upscale_models")], set(supported_pt_extensions)) +folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.path.join(base_path, "custom_nodes")], set()) +folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions)) +folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions)) +folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""}) output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") user_directory = os.path.join(base_path, "user") -filename_list_cache = {} +_filename_list_cache = {} if not os.path.exists(input_directory): try: @@ -57,32 +119,38 @@ if not os.path.exists(input_directory): except: logging.error("Failed to create input directory") + def set_output_directory(output_dir): global output_directory output_directory = output_dir + def set_temp_directory(temp_dir): global temp_directory temp_directory = temp_dir + def set_input_directory(input_dir): global input_directory input_directory = input_dir + def get_output_directory(): global output_directory return output_directory + def get_temp_directory(): global temp_directory return temp_directory + def get_input_directory(): global input_directory return input_directory -#NOTE: used in http server so don't put folders that should not be accessed remotely +# NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): if type_name == "output": return get_output_directory() @@ -133,17 +201,29 @@ def exists_annotated_filepath(name): return os.path.exists(filepath) -def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None): +def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -> str: + """ + Registers a model path for the given canonical name. + :param folder_name: the folder name + :param full_folder_path: When none, defaults to os.path.join(models_dir, folder_name) aka the folder as + a subpath to the default models directory + :return: the folder path + """ global folder_names_and_paths if full_folder_path is None: full_folder_path = os.path.join(models_dir, folder_name) - if folder_name in folder_names_and_paths and full_folder_path not in folder_names_and_paths[folder_name][0]: - folder_names_and_paths[folder_name][0].append(full_folder_path) - else: - folder_names_and_paths[folder_name] = ([full_folder_path], set()) -def get_folder_paths(folder_name): - return folder_names_and_paths[folder_name][0][:] + folder_path = folder_names_and_paths[folder_name] + if full_folder_path not in folder_path.paths: + folder_path.paths.append(full_folder_path) + + invalidate_cache(folder_name) + return full_folder_path + + +def get_folder_paths(folder_name) -> List[str]: + return folder_names_and_paths[folder_name].paths[:] + def recursive_search(directory, excluded_dir_names=None): if not os.path.isdir(directory): @@ -176,24 +256,41 @@ def recursive_search(directory, excluded_dir_names=None): continue return result, dirs + def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files))) - def get_full_path(folder_name, filename): + """ + Gets the path to a filename inside a folder. + + Works with untrusted filenames. + :param folder_name: + :param filename: + :return: + """ global folder_names_and_paths - if folder_name not in folder_names_and_paths: - return None - folders = folder_names_and_paths[folder_name] - filename = os.path.relpath(os.path.join("/", filename), "/") - for x in folders[0]: - full_path = os.path.join(x, filename) - if os.path.isfile(full_path): - return full_path + folders = folder_names_and_paths[folder_name].paths + filename_split = os.path.split(filename) + + trusted_paths = [] + for folder in folders: + folder_split = os.path.split(folder) + abs_file_path = os.path.abspath(os.path.join(*folder_split, *filename_split)) + abs_folder_path = os.path.abspath(folder) + if os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path: + trusted_paths.append(abs_file_path) + else: + logging.error(f"attempted to access untrusted path {abs_file_path} in {folder_name} for filename {filename}") + + for trusted_path in trusted_paths: + if os.path.isfile(trusted_path): + return trusted_path return None + def get_filename_list_(folder_name): global folder_names_and_paths output_list = set() @@ -204,14 +301,15 @@ def get_filename_list_(folder_name): output_list.update(filter_files_extensions(files, folders[1])) output_folders = {**output_folders, **folders_all} - return (sorted(list(output_list)), output_folders, time.perf_counter()) + return sorted(list(output_list)), output_folders, time.perf_counter() + def cached_filename_list_(folder_name): - global filename_list_cache + global _filename_list_cache global folder_names_and_paths - if folder_name not in filename_list_cache: + if folder_name not in _filename_list_cache: return None - out = filename_list_cache[folder_name] + out = _filename_list_cache[folder_name] for x in out[1]: time_modified = out[1][x] @@ -227,14 +325,16 @@ def cached_filename_list_(folder_name): return out + def get_filename_list(folder_name): out = cached_filename_list_(folder_name) if out is None: out = get_filename_list_(folder_name) - global filename_list_cache - filename_list_cache[folder_name] = out + global _filename_list_cache + _filename_list_cache[folder_name] = out return list(out[0]) + def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): prefix_len = len(os.path.basename(filename_prefix)) @@ -255,7 +355,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height subfolder = os.path.dirname(os.path.normpath(filename_prefix)) filename = os.path.basename(os.path.normpath(filename_prefix)) - full_output_folder = os.path.join(output_dir, subfolder) + full_output_folder = str(os.path.join(output_dir, subfolder)) if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: err = "**** ERROR: Saving image outside the output folder is not allowed." + \ @@ -276,8 +376,14 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height def create_directories(): - for _, (paths, _) in folder_names_and_paths.items(): - default_path = paths[0] - os.makedirs(default_path, exist_ok=True) + # all configured paths should be created + for folder_path_spec in folder_names_and_paths.values(): + for path in folder_path_spec.paths: + os.makedirs(path, exist_ok=True) for path in (temp_directory, input_directory, output_directory, user_directory): os.makedirs(path, exist_ok=True) + + +def invalidate_cache(folder_name): + global _filename_list_cache + _filename_list_cache.pop(folder_name, None) diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index b9906d089..f144738cd 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -1,94 +1,26 @@ -from .. import options # Suppress warnings during import -import warnings - -warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") -options.enable_args_parsing() - +import asyncio +import gc +import itertools import logging import os -import importlib.util - -from ..cmd import cuda_malloc -from ..cmd import folder_paths -from .extra_model_paths import load_extra_path_config -from ..analytics.analytics import initialize_event_tracking -from ..nodes.package import import_all_nodes_in_workspace - -import time - - -def execute_prestartup_script(): - def execute_script(script_path): - module_name = os.path.splitext(script_path)[0] - try: - spec = importlib.util.spec_from_file_location(module_name, script_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return True - except Exception as e: - logging.error(f"Failed to execute startup-script: {script_path} / {e}") - return False - - node_paths = folder_paths.get_folder_paths("custom_nodes") - node_prestartup_times = [] - for custom_node_path in node_paths: - possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else [] - - for possible_module in possible_modules: - module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": - continue - - script_path = os.path.join(module_path, "prestartup_script.py") - if os.path.exists(script_path): - time_before = time.perf_counter() - success = execute_script(script_path) - node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) - if len(node_prestartup_times) > 0: - logging.info("\nPrestartup times for custom nodes:") - for n in sorted(node_prestartup_times): - if n[2]: - import_message = "" - else: - import_message = " (PRESTARTUP FAILED)" - logging.info("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) - - -execute_prestartup_script() - -# Main code -import asyncio -import itertools import shutil import threading -import gc - -from ..cli_args import args - -if os.name == "nt": - import logging - - logging.getLogger("xformers").addFilter( - lambda record: 'A matching Triton is not available' not in record.getMessage()) - -if args.cuda_device is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - logging.info("Set cuda device to:", args.cuda_device) - -if args.deterministic: - if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" - -from .. import utils +import time +from comfy.utils import hijack_progress +from .extra_model_paths import load_extra_path_config +from .main_pre import args +from .. import model_management +from ..analytics.analytics import initialize_event_tracking +from ..cmd import cuda_malloc +from ..cmd import folder_paths from ..cmd import server as server_module from ..component_model.abstract_prompt_queue import AbstractPromptQueue -from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus -from .. import model_management +from ..component_model.queue_types import ExecutionStatus from ..distributed.distributed_prompt_queue import DistributedPromptQueue -from ..component_model.executor_types import ExecutorToClientProgress from ..distributed.server_stub import ServerStub +from ..nodes.package import import_all_nodes_in_workspace def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer): @@ -120,7 +52,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer): completed=e.success, messages=e.status_messages)) if _server.client_id is not None: - _server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id) + _server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -152,18 +84,6 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) -def hijack_progress(server: ExecutorToClientProgress): - def hook(value: float, total: float, preview_image): - model_management.throw_exception_if_processing_interrupted() - progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} - - server.send_sync("progress", progress, server.client_id) - if preview_image is not None: - server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) - - utils.set_progress_bar_global_hook(hook) - - def cleanup_temp(): try: temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") @@ -194,6 +114,18 @@ async def main(): folder_paths.set_temp_directory(temp_dir) cleanup_temp() + # configure extra model paths earlier + try: + extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") + if os.path.isfile(extra_model_paths_config_path): + load_extra_path_config(extra_model_paths_config_path) + except NameError: + pass + + if args.extra_model_paths_config: + for config_path in itertools.chain(*args.extra_model_paths_config): + load_extra_path_config(config_path) + # create the default directories if we're instructed to, then exit # or, if it's a windows standalone build, the single .exe file should have its side-by-side directories always created if args.create_directories: @@ -208,18 +140,6 @@ async def main(): except: pass - # configure extra model paths earlier - try: - extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") - if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) - except NameError: - pass - - if args.extra_model_paths_config: - for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) - loop = asyncio.get_event_loop() server = server_module.PromptServer(loop) if args.external_address is not None: @@ -246,7 +166,6 @@ async def main(): q = PromptQueue(server) server.prompt_queue = q - server.add_routes() hijack_progress(server) cuda_malloc_warning() @@ -293,7 +212,7 @@ async def main(): server.port = args.port try: await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, - call_on_start=call_on_start) + call_on_start=call_on_start) except asyncio.CancelledError: if distributed: await q.close() diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py new file mode 100644 index 000000000..b78ddf231 --- /dev/null +++ b/comfy/cmd/main_pre.py @@ -0,0 +1,23 @@ +import os + +from .. import options + +import warnings +import logging + +options.enable_args_parsing() +if os.name == "nt": + logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) +warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") + +from ..cli_args import args + +if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + logging.info("Set cuda device to:", args.cuda_device) + +if args.deterministic: + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + +__all__ = ["args"] diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 9c29a91a0..fa7c1cce1 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -23,6 +23,7 @@ import aiofiles import aiohttp from aiohttp import web +import comfy.interruption from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation from ..cmd import execution from ..cmd import folder_paths @@ -515,7 +516,7 @@ class PromptServer(ExecutorToClientProgress): @routes.post("/interrupt") async def post_interrupt(request): - model_management.interrupt_current_processing() + comfy.interruption.interrupt_current_processing() return web.Response(status=200) @routes.post("/free") diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py index 048b15c18..49b0054ab 100644 --- a/comfy/distributed/distributed_progress.py +++ b/comfy/distributed/distributed_progress.py @@ -10,6 +10,7 @@ from aio_pika.patterns import RPC from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress from ..component_model.queue_types import BinaryEventTypes +from ..utils import hijack_progress async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None, @@ -24,7 +25,7 @@ def _get_name(queue_name: str, user_id: str) -> str: class DistributedExecutorToClientProgress(ExecutorToClientProgress): - def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop): + def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True): self._rpc = rpc self._queue_name = queue_name self._loop = loop @@ -32,6 +33,8 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress): self.client_id = None self.node_id = None self.last_node_id = None + if receive_all_progress_notifications: + hijack_progress(self) async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None: # for now, do not send binary data this way, since it cannot be json serialized / it's impractical diff --git a/comfy/interruption.py b/comfy/interruption.py new file mode 100644 index 000000000..5c4cd0f7d --- /dev/null +++ b/comfy/interruption.py @@ -0,0 +1,31 @@ +import threading + +_interrupt_processing_mutex = threading.RLock() +_interrupt_processing = False + + +class InterruptProcessingException(Exception): + pass + + +def interrupt_current_processing(value=True): + global _interrupt_processing + global _interrupt_processing_mutex + with _interrupt_processing_mutex: + _interrupt_processing = value + + +def processing_interrupted(): + global _interrupt_processing + global _interrupt_processing_mutex + with _interrupt_processing_mutex: + return _interrupt_processing + + +def throw_exception_if_processing_interrupted(): + global _interrupt_processing + global _interrupt_processing_mutex + with _interrupt_processing_mutex: + if _interrupt_processing: + _interrupt_processing = False + raise InterruptProcessingException() diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 5adb95f33..c294c3728 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -1,15 +1,18 @@ from __future__ import annotations import logging +import os +from itertools import chain from os.path import join -from typing import List, Any, Optional +from typing import List, Any, Optional, Union from huggingface_hub import hf_hub_download from requests import Session from .cmd import folder_paths -from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse -from .utils import comfy_tqdm, ProgressBar +from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_ +from .interruption import InterruptProcessingException +from .utils import ProgressBar, comfy_tqdm session = Session() @@ -20,32 +23,40 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any] return sorted(list(existing | downloadable)) -def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> str: +def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> Optional[str]: path = folder_paths.get_full_path(folder_name, filename) if path is None: try: + # todo: should this be the first or last path? destination = folder_paths.get_folder_paths(folder_name)[0] - known_file = next(f for f in known_files if str(f) == filename) + known_file: Optional[HuggingFile | CivitFile] = None + for candidate in known_files: + if candidate.filename == filename or filename in candidate.alternate_filenames or filename == candidate.save_with_filename: + known_file = candidate + break + if known_file is None: + return path with comfy_tqdm(): if isinstance(known_file, HuggingFile): + save_filename = known_file.save_with_filename or known_file.filename path = hf_hub_download(repo_id=known_file.repo_id, - filename=known_file.filename, + filename=save_filename, local_dir=destination, resume_download=True) else: url: Optional[str] = None + save_filename = known_file.save_with_filename or known_file.filename if isinstance(known_file, CivitFile): model_info_res = session.get( f"https://civitai.com/api/v1/models/{known_file.model_id}?modelVersionId={known_file.model_version_id}") model_info: CivitModelsGetResponse = model_info_res.json() - for model_version in model_info['modelVersions']: - for file in model_version['files']: - if file['name'] == filename: - url = file['downloadUrl'] - break - if url is not None: + + file: CivitFile_ + for file in chain.from_iterable(version['files'] for version in model_info['modelVersions']): + if file['name'] == filename: + url = file['downloadUrl'] break else: raise RuntimeError("unknown file type") @@ -53,19 +64,29 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi if url is None: logging.warning(f"Could not retrieve file {str(known_file)}") else: - with session.get(url, stream=True, allow_redirects=True) as response: - total_size = int(response.headers.get("content-length", 0)) - progress_bar = ProgressBar(total=total_size) - with open(join(destination, filename), "wb") as file: - for chunk in response.iter_content(chunk_size=512 * 1024): - progress_bar.update(len(chunk)) - file.write(chunk) + destination_with_filename = join(destination, save_filename) + try: + + with session.get(url, stream=True, allow_redirects=True) as response: + total_size = int(response.headers.get("content-length", 0)) + progress_bar = ProgressBar(total=total_size) + with open(destination_with_filename, "wb") as file: + for chunk in response.iter_content(chunk_size=512 * 1024): + progress_bar.update(len(chunk)) + file.write(chunk) + except InterruptProcessingException: + os.remove(destination_with_filename) + path = folder_paths.get_full_path(folder_name, filename) assert path is not None except StopIteration: pass except Exception as exc: logging.error("Error while trying to download a file", exc_info=exc) + finally: + # a path was found for any reason, so we should invalidate the cache + if path is not None: + folder_paths.invalidate_cache(folder_name) return path @@ -118,5 +139,10 @@ KNOWN_CLIP_VISION_MODELS = [ KNOWN_LORAS = [ CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"), - # todo: a lot of the slider loras are useful and should also be included -] \ No newline at end of file +] + + +def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]: + symbol += models + folder_paths.invalidate_cache(folder_name) + return symbol diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index 99f11a6be..2c5d522a2 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -22,6 +22,14 @@ class CivitFile: def __str__(self): return self.filename + @property + def save_with_filename(self): + return self.filename + + @property + def alternate_filenames(self): + return [] + @dataclasses.dataclass class HuggingFile: @@ -35,6 +43,8 @@ class HuggingFile: """ repo_id: str filename: str + save_with_filename: Optional[str] = None + alternate_filenames: List[str] = dataclasses.field(default_factory=list) show_in_ui: Optional[bool] = True def __str__(self): diff --git a/comfy/model_management.py b/comfy/model_management.py index f6b1f9426..e4891afdf 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -2,7 +2,7 @@ import psutil import logging from enum import Enum from .cli_args import args -from . import utils +from . import interruption from threading import RLock import torch @@ -840,31 +840,14 @@ def unload_all_models(): def resolve_lowvram_weight(weight, model, key): #TODO: remove return weight -#TODO: might be cleaner to put this somewhere else -import threading -class InterruptProcessingException(Exception): - pass - -interrupt_processing_mutex = threading.RLock() - -interrupt_processing = False def interrupt_current_processing(value=True): - global interrupt_processing - global interrupt_processing_mutex - with interrupt_processing_mutex: - interrupt_processing = value + interruption.interrupt_current_processing(value) + def processing_interrupted(): - global interrupt_processing - global interrupt_processing_mutex - with interrupt_processing_mutex: - return interrupt_processing + interruption.processing_interrupted() + def throw_exception_if_processing_interrupted(): - global interrupt_processing - global interrupt_processing_mutex - with interrupt_processing_mutex: - if interrupt_processing: - interrupt_processing = False - raise InterruptProcessingException() + interruption.throw_exception_if_processing_interrupted() diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 450682776..6c70b803f 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -25,7 +25,6 @@ from ..cli_args import args from ..cmd import folder_paths, latent_preview from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS -from ..model_downloader_types import HuggingFile from ..nodes.common import MAX_RESOLUTION from .. import controlnet diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index 75ab5b853..4a8c8b721 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -10,11 +10,8 @@ from functools import reduce from importlib.metadata import entry_points from pkg_resources import resource_filename - -from comfy_extras import nodes as comfy_extras_nodes -from . import base_nodes from .package_typing import ExportedNodes -from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes + _comfy_nodes: ExportedNodes = ExportedNodes() @@ -77,6 +74,10 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: + # now actually import the nodes, to improve control of node loading order + from comfy_extras import nodes as comfy_extras_nodes + from . import base_nodes + from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes # only load these nodes once if len(_comfy_nodes) == 0: base_and_extra = reduce(lambda x, y: x.update(y), diff --git a/comfy/utils.py b/comfy/utils.py index cb843dc28..ab41799e3 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,16 +1,26 @@ import os.path +from contextlib import contextmanager import torch import math import struct -from . import checkpoint_pickle + +from tqdm import tqdm + +from . import checkpoint_pickle, interruption import safetensors.torch import numpy as np from PIL import Image -from tqdm import tqdm -from contextlib import contextmanager import logging +from .component_model.executor_types import ExecutorToClientProgress +from .component_model.queue_types import BinaryEventTypes + +PROGRESS_BAR_ENABLED = True +PROGRESS_BAR_HOOK = None + + + def load_torch_file(ckpt, safe_load=False, device=None): if device is None: device = torch.device("cpu") @@ -457,16 +467,33 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am output[b:b+1] = out/out_div return output -PROGRESS_BAR_ENABLED = True + +def hijack_progress(server: ExecutorToClientProgress): + def hook(value: float, total: float, preview_image): + interruption.throw_exception_if_processing_interrupted() + progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + + server.send_sync("progress", progress, server.client_id) + if preview_image is not None: + server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) + + set_progress_bar_global_hook(hook) + + def set_progress_bar_enabled(enabled): global PROGRESS_BAR_ENABLED PROGRESS_BAR_ENABLED = enabled -PROGRESS_BAR_HOOK = None + +def get_progress_bar_enabled() -> bool: + return PROGRESS_BAR_ENABLED + + def set_progress_bar_global_hook(function): global PROGRESS_BAR_HOOK PROGRESS_BAR_HOOK = function + class ProgressBar: def __init__(self, total: float): global PROGRESS_BAR_HOOK diff --git a/requirements.txt b/requirements.txt index ce9d90b2a..615d4beb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,4 +31,5 @@ aio-pika pyjwt[crypto] kornia>=0.7.1 mpmath>=1.0,!=1.4.0a0 -huggingface_hub \ No newline at end of file +huggingface_hub +lazy-object-proxy \ No newline at end of file