Improvements to node loading, node API, folder paths and progress

- Improve node loading order. It now occurs "as late as possible".
   Configuration should be exposed as per the README.
 - Added methods to specify custom folders and models used in examples
   more robustly for custom nodes.
 - Downloading models can now be gracefully interrupted.
 - Progress notifications are now sent over the network for distributed
   ComfyUI operations.
 - Python objects have been moved around to prevent less transitive
   package importing issues.
This commit is contained in:
doctorpangloss 2024-03-13 16:14:18 -07:00
parent 3ccbda36da
commit 341c9f2e90
15 changed files with 357 additions and 221 deletions

4
.editorconfig Normal file
View File

@ -0,0 +1,4 @@
root = true
[*]
max_line_length = off

View File

@ -11,7 +11,9 @@ import typing
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from typing_extensions import TypedDict from typing_extensions import TypedDict
import torch import torch
import lazy_object_proxy
from .. import interruption
from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
@ -21,7 +23,7 @@ from ..nodes.package import import_all_nodes_in_workspace
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place. # various functions that are declared here. It should have been a context in the first place.
nodes: ExportedNodes = import_all_nodes_in_workspace() nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None): def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
if extra_data is None: if extra_data is None:
@ -80,16 +82,16 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
results = [] results = []
if input_is_list: if input_is_list:
if allow_interrupt: if allow_interrupt:
model_management.throw_exception_if_processing_interrupted() interruption.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)(**input_data_all)) results.append(getattr(obj, func)(**input_data_all))
elif max_len_input == 0: elif max_len_input == 0:
if allow_interrupt: if allow_interrupt:
model_management.throw_exception_if_processing_interrupted() interruption.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)()) results.append(getattr(obj, func)())
else: else:
for i in range(max_len_input): for i in range(max_len_input):
if allow_interrupt: if allow_interrupt:
model_management.throw_exception_if_processing_interrupted() interruption.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results return results
@ -185,7 +187,7 @@ def recursive_execute(server: ExecutorToClientProgress,
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
server.client_id) server.client_id)
except model_management.InterruptProcessingException as iex: except interruption.InterruptProcessingException as iex:
logging.info("Processing interrupted") logging.info("Processing interrupted")
# skip formatting inputs/outputs # skip formatting inputs/outputs
@ -327,7 +329,7 @@ class PromptExecutor:
# First, send back the status to the frontend depending # First, send back the status to the frontend depending
# on the exception type # on the exception type
if isinstance(ex, model_management.InterruptProcessingException): if isinstance(ex, interruption.InterruptProcessingException):
mes = { mes = {
"prompt_id": prompt_id, "prompt_id": prompt_id,
"node_id": node_id, "node_id": node_id,
@ -367,7 +369,7 @@ class PromptExecutor:
execute_outputs = [] execute_outputs = []
if extra_data is None: if extra_data is None:
extra_data = {} extra_data = {}
model_management.interrupt_current_processing(False) interruption.interrupt_current_processing(False)
if "client_id" in extra_data: if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"] self.server.client_id = extra_data["client_id"]

View File

@ -1,17 +1,79 @@
import dataclasses
import os import os
import posixpath
import sys import sys
import time import time
import logging import logging
from typing import Optional from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
from pkg_resources import resource_filename from pkg_resources import resource_filename
from ..cli_args import args from ..cli_args import args
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
folder_names_and_paths = {}
@dataclasses.dataclass
class FolderPathsTuple:
folder_name: str
paths: List[str] = dataclasses.field(default_factory=list)
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
def __getitem__(self, item: Any):
if item == 0:
return self.paths
if item == 1:
return self.supported_extensions
else:
raise RuntimeError("unsupported tuple index")
def __add__(self, other: "FolderPathsTuple"):
assert self.folder_name == other.folder_name
new_paths = list(frozenset(self.paths + other.paths))
new_supported_extensions = self.supported_extensions | other.supported_extensions
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
def __iter__(self) -> Iterator[Sequence[str]]:
yield self.paths
yield self.supported_extensions
class FolderNames:
def __init__(self):
self.contents: Dict[str, FolderPathsTuple] = dict()
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
raise RuntimeError("expected folder path")
if item not in self.contents:
self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set())
return self.contents[item]
def __setitem__(self, key: str, value: FolderPathsTuple):
assert isinstance(key, str)
if isinstance(value, tuple):
paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions)
if key in self.contents:
value = self.contents[key] + value
self.contents[key] = value
def __len__(self):
return len(self.contents)
def items(self):
return self.contents.items()
def values(self):
return self.contents.values()
def keys(self):
return self.contents.keys()
folder_names_and_paths = FolderNames()
# todo: this should be initialized elsewhere
if 'main.py' in sys.argv: if 'main.py' in sys.argv:
base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
elif args.cwd is not None: elif args.cwd is not None:
@ -25,31 +87,31 @@ elif args.cwd is not None:
else: else:
base_path = os.getcwd() base_path = os.getcwd()
models_dir = os.path.join(base_path, "models") models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions) folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], set([".yaml"])) folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions) folder_names_and_paths["unet"] = FolderPathsTuple("unet", [os.path.join(models_dir, "unet")], set(supported_pt_extensions))
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions))
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions))
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions))
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) folder_names_and_paths["diffusers"] = FolderPathsTuple("diffusers", [os.path.join(models_dir, "diffusers")], {"folder"})
folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions) folder_names_and_paths["vae_approx"] = FolderPathsTuple("vae_approx", [os.path.join(models_dir, "vae_approx")], set(supported_pt_extensions))
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["controlnet"] = FolderPathsTuple("controlnet", [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], set(supported_pt_extensions))
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) folder_names_and_paths["gligen"] = FolderPathsTuple("gligen", [os.path.join(models_dir, "gligen")], set(supported_pt_extensions))
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = FolderPathsTuple("upscale_models", [os.path.join(models_dir, "upscale_models")], set(supported_pt_extensions))
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.path.join(base_path, "custom_nodes")], set())
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input") input_directory = os.path.join(base_path, "input")
user_directory = os.path.join(base_path, "user") user_directory = os.path.join(base_path, "user")
filename_list_cache = {} _filename_list_cache = {}
if not os.path.exists(input_directory): if not os.path.exists(input_directory):
try: try:
@ -57,32 +119,38 @@ if not os.path.exists(input_directory):
except: except:
logging.error("Failed to create input directory") logging.error("Failed to create input directory")
def set_output_directory(output_dir): def set_output_directory(output_dir):
global output_directory global output_directory
output_directory = output_dir output_directory = output_dir
def set_temp_directory(temp_dir): def set_temp_directory(temp_dir):
global temp_directory global temp_directory
temp_directory = temp_dir temp_directory = temp_dir
def set_input_directory(input_dir): def set_input_directory(input_dir):
global input_directory global input_directory
input_directory = input_dir input_directory = input_dir
def get_output_directory(): def get_output_directory():
global output_directory global output_directory
return output_directory return output_directory
def get_temp_directory(): def get_temp_directory():
global temp_directory global temp_directory
return temp_directory return temp_directory
def get_input_directory(): def get_input_directory():
global input_directory global input_directory
return input_directory return input_directory
#NOTE: used in http server so don't put folders that should not be accessed remotely # NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name): def get_directory_by_type(type_name):
if type_name == "output": if type_name == "output":
return get_output_directory() return get_output_directory()
@ -133,17 +201,29 @@ def exists_annotated_filepath(name):
return os.path.exists(filepath) return os.path.exists(filepath)
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None): def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -> str:
"""
Registers a model path for the given canonical name.
:param folder_name: the folder name
:param full_folder_path: When none, defaults to os.path.join(models_dir, folder_name) aka the folder as
a subpath to the default models directory
:return: the folder path
"""
global folder_names_and_paths global folder_names_and_paths
if full_folder_path is None: if full_folder_path is None:
full_folder_path = os.path.join(models_dir, folder_name) full_folder_path = os.path.join(models_dir, folder_name)
if folder_name in folder_names_and_paths and full_folder_path not in folder_names_and_paths[folder_name][0]:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())
def get_folder_paths(folder_name): folder_path = folder_names_and_paths[folder_name]
return folder_names_and_paths[folder_name][0][:] if full_folder_path not in folder_path.paths:
folder_path.paths.append(full_folder_path)
invalidate_cache(folder_name)
return full_folder_path
def get_folder_paths(folder_name) -> List[str]:
return folder_names_and_paths[folder_name].paths[:]
def recursive_search(directory, excluded_dir_names=None): def recursive_search(directory, excluded_dir_names=None):
if not os.path.isdir(directory): if not os.path.isdir(directory):
@ -176,24 +256,41 @@ def recursive_search(directory, excluded_dir_names=None):
continue continue
return result, dirs return result, dirs
def filter_files_extensions(files, extensions): def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files))) return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
def get_full_path(folder_name, filename): def get_full_path(folder_name, filename):
"""
Gets the path to a filename inside a folder.
Works with untrusted filenames.
:param folder_name:
:param filename:
:return:
"""
global folder_names_and_paths global folder_names_and_paths
if folder_name not in folder_names_and_paths: folders = folder_names_and_paths[folder_name].paths
return None filename_split = os.path.split(filename)
folders = folder_names_and_paths[folder_name]
filename = os.path.relpath(os.path.join("/", filename), "/") trusted_paths = []
for x in folders[0]: for folder in folders:
full_path = os.path.join(x, filename) folder_split = os.path.split(folder)
if os.path.isfile(full_path): abs_file_path = os.path.abspath(os.path.join(*folder_split, *filename_split))
return full_path abs_folder_path = os.path.abspath(folder)
if os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path:
trusted_paths.append(abs_file_path)
else:
logging.error(f"attempted to access untrusted path {abs_file_path} in {folder_name} for filename {filename}")
for trusted_path in trusted_paths:
if os.path.isfile(trusted_path):
return trusted_path
return None return None
def get_filename_list_(folder_name): def get_filename_list_(folder_name):
global folder_names_and_paths global folder_names_and_paths
output_list = set() output_list = set()
@ -204,14 +301,15 @@ def get_filename_list_(folder_name):
output_list.update(filter_files_extensions(files, folders[1])) output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all} output_folders = {**output_folders, **folders_all}
return (sorted(list(output_list)), output_folders, time.perf_counter()) return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name): def cached_filename_list_(folder_name):
global filename_list_cache global _filename_list_cache
global folder_names_and_paths global folder_names_and_paths
if folder_name not in filename_list_cache: if folder_name not in _filename_list_cache:
return None return None
out = filename_list_cache[folder_name] out = _filename_list_cache[folder_name]
for x in out[1]: for x in out[1]:
time_modified = out[1][x] time_modified = out[1][x]
@ -227,14 +325,16 @@ def cached_filename_list_(folder_name):
return out return out
def get_filename_list(folder_name): def get_filename_list(folder_name):
out = cached_filename_list_(folder_name) out = cached_filename_list_(folder_name)
if out is None: if out is None:
out = get_filename_list_(folder_name) out = get_filename_list_(folder_name)
global filename_list_cache global _filename_list_cache
filename_list_cache[folder_name] = out _filename_list_cache[folder_name] = out
return list(out[0]) return list(out[0])
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename): def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix)) prefix_len = len(os.path.basename(filename_prefix))
@ -255,7 +355,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
subfolder = os.path.dirname(os.path.normpath(filename_prefix)) subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix)) filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(output_dir, subfolder) full_output_folder = str(os.path.join(output_dir, subfolder))
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
err = "**** ERROR: Saving image outside the output folder is not allowed." + \ err = "**** ERROR: Saving image outside the output folder is not allowed." + \
@ -276,8 +376,14 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
def create_directories(): def create_directories():
for _, (paths, _) in folder_names_and_paths.items(): # all configured paths should be created
default_path = paths[0] for folder_path_spec in folder_names_and_paths.values():
os.makedirs(default_path, exist_ok=True) for path in folder_path_spec.paths:
os.makedirs(path, exist_ok=True)
for path in (temp_directory, input_directory, output_directory, user_directory): for path in (temp_directory, input_directory, output_directory, user_directory):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
def invalidate_cache(folder_name):
global _filename_list_cache
_filename_list_cache.pop(folder_name, None)

View File

@ -1,94 +1,26 @@
from .. import options
# Suppress warnings during import # Suppress warnings during import
import warnings import asyncio
import gc
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") import itertools
options.enable_args_parsing()
import logging import logging
import os import os
import importlib.util
from ..cmd import cuda_malloc
from ..cmd import folder_paths
from .extra_model_paths import load_extra_path_config
from ..analytics.analytics import initialize_event_tracking
from ..nodes.package import import_all_nodes_in_workspace
import time
def execute_prestartup_script():
def execute_script(script_path):
module_name = os.path.splitext(script_path)[0]
try:
spec = importlib.util.spec_from_file_location(module_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return True
except Exception as e:
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_prestartup_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else []
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
continue
script_path = os.path.join(module_path, "prestartup_script.py")
if os.path.exists(script_path):
time_before = time.perf_counter()
success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_prestartup_times) > 0:
logging.info("\nPrestartup times for custom nodes:")
for n in sorted(node_prestartup_times):
if n[2]:
import_message = ""
else:
import_message = " (PRESTARTUP FAILED)"
logging.info("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
execute_prestartup_script()
# Main code
import asyncio
import itertools
import shutil import shutil
import threading import threading
import gc import time
from ..cli_args import args
if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(
lambda record: 'A matching Triton is not available' not in record.getMessage())
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to:", args.cuda_device)
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
from .. import utils
from comfy.utils import hijack_progress
from .extra_model_paths import load_extra_path_config
from .main_pre import args
from .. import model_management
from ..analytics.analytics import initialize_event_tracking
from ..cmd import cuda_malloc
from ..cmd import folder_paths
from ..cmd import server as server_module from ..cmd import server as server_module
from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus from ..component_model.queue_types import ExecutionStatus
from .. import model_management
from ..distributed.distributed_prompt_queue import DistributedPromptQueue from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..distributed.server_stub import ServerStub from ..distributed.server_stub import ServerStub
from ..nodes.package import import_all_nodes_in_workspace
def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer): def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
@ -120,7 +52,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
completed=e.success, completed=e.success,
messages=e.status_messages)) messages=e.status_messages))
if _server.client_id is not None: if _server.client_id is not None:
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id) _server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
@ -152,18 +84,6 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server: ExecutorToClientProgress):
def hook(value: float, total: float, preview_image):
model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
utils.set_progress_bar_global_hook(hook)
def cleanup_temp(): def cleanup_temp():
try: try:
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
@ -194,6 +114,18 @@ async def main():
folder_paths.set_temp_directory(temp_dir) folder_paths.set_temp_directory(temp_dir)
cleanup_temp() cleanup_temp()
# configure extra model paths earlier
try:
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
except NameError:
pass
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
# create the default directories if we're instructed to, then exit # create the default directories if we're instructed to, then exit
# or, if it's a windows standalone build, the single .exe file should have its side-by-side directories always created # or, if it's a windows standalone build, the single .exe file should have its side-by-side directories always created
if args.create_directories: if args.create_directories:
@ -208,18 +140,6 @@ async def main():
except: except:
pass pass
# configure extra model paths earlier
try:
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
except NameError:
pass
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
server = server_module.PromptServer(loop) server = server_module.PromptServer(loop)
if args.external_address is not None: if args.external_address is not None:
@ -246,7 +166,6 @@ async def main():
q = PromptQueue(server) q = PromptQueue(server)
server.prompt_queue = q server.prompt_queue = q
server.add_routes() server.add_routes()
hijack_progress(server) hijack_progress(server)
cuda_malloc_warning() cuda_malloc_warning()
@ -293,7 +212,7 @@ async def main():
server.port = args.port server.port = args.port
try: try:
await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
call_on_start=call_on_start) call_on_start=call_on_start)
except asyncio.CancelledError: except asyncio.CancelledError:
if distributed: if distributed:
await q.close() await q.close()

23
comfy/cmd/main_pre.py Normal file
View File

@ -0,0 +1,23 @@
import os
from .. import options
import warnings
import logging
options.enable_args_parsing()
if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
from ..cli_args import args
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to:", args.cuda_device)
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
__all__ = ["args"]

View File

@ -23,6 +23,7 @@ import aiofiles
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
import comfy.interruption
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
from ..cmd import execution from ..cmd import execution
from ..cmd import folder_paths from ..cmd import folder_paths
@ -515,7 +516,7 @@ class PromptServer(ExecutorToClientProgress):
@routes.post("/interrupt") @routes.post("/interrupt")
async def post_interrupt(request): async def post_interrupt(request):
model_management.interrupt_current_processing() comfy.interruption.interrupt_current_processing()
return web.Response(status=200) return web.Response(status=200)
@routes.post("/free") @routes.post("/free")

View File

@ -10,6 +10,7 @@ from aio_pika.patterns import RPC
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
from ..component_model.queue_types import BinaryEventTypes from ..component_model.queue_types import BinaryEventTypes
from ..utils import hijack_progress
async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None, async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
@ -24,7 +25,7 @@ def _get_name(queue_name: str, user_id: str) -> str:
class DistributedExecutorToClientProgress(ExecutorToClientProgress): class DistributedExecutorToClientProgress(ExecutorToClientProgress):
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop): def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
self._rpc = rpc self._rpc = rpc
self._queue_name = queue_name self._queue_name = queue_name
self._loop = loop self._loop = loop
@ -32,6 +33,8 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
self.client_id = None self.client_id = None
self.node_id = None self.node_id = None
self.last_node_id = None self.last_node_id = None
if receive_all_progress_notifications:
hijack_progress(self)
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None: async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical # for now, do not send binary data this way, since it cannot be json serialized / it's impractical

31
comfy/interruption.py Normal file
View File

@ -0,0 +1,31 @@
import threading
_interrupt_processing_mutex = threading.RLock()
_interrupt_processing = False
class InterruptProcessingException(Exception):
pass
def interrupt_current_processing(value=True):
global _interrupt_processing
global _interrupt_processing_mutex
with _interrupt_processing_mutex:
_interrupt_processing = value
def processing_interrupted():
global _interrupt_processing
global _interrupt_processing_mutex
with _interrupt_processing_mutex:
return _interrupt_processing
def throw_exception_if_processing_interrupted():
global _interrupt_processing
global _interrupt_processing_mutex
with _interrupt_processing_mutex:
if _interrupt_processing:
_interrupt_processing = False
raise InterruptProcessingException()

View File

@ -1,15 +1,18 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
from itertools import chain
from os.path import join from os.path import join
from typing import List, Any, Optional from typing import List, Any, Optional, Union
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from requests import Session from requests import Session
from .cmd import folder_paths from .cmd import folder_paths
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
from .utils import comfy_tqdm, ProgressBar from .interruption import InterruptProcessingException
from .utils import ProgressBar, comfy_tqdm
session = Session() session = Session()
@ -20,32 +23,40 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]
return sorted(list(existing | downloadable)) return sorted(list(existing | downloadable))
def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> str: def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> Optional[str]:
path = folder_paths.get_full_path(folder_name, filename) path = folder_paths.get_full_path(folder_name, filename)
if path is None: if path is None:
try: try:
# todo: should this be the first or last path?
destination = folder_paths.get_folder_paths(folder_name)[0] destination = folder_paths.get_folder_paths(folder_name)[0]
known_file = next(f for f in known_files if str(f) == filename) known_file: Optional[HuggingFile | CivitFile] = None
for candidate in known_files:
if candidate.filename == filename or filename in candidate.alternate_filenames or filename == candidate.save_with_filename:
known_file = candidate
break
if known_file is None:
return path
with comfy_tqdm(): with comfy_tqdm():
if isinstance(known_file, HuggingFile): if isinstance(known_file, HuggingFile):
save_filename = known_file.save_with_filename or known_file.filename
path = hf_hub_download(repo_id=known_file.repo_id, path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename, filename=save_filename,
local_dir=destination, local_dir=destination,
resume_download=True) resume_download=True)
else: else:
url: Optional[str] = None url: Optional[str] = None
save_filename = known_file.save_with_filename or known_file.filename
if isinstance(known_file, CivitFile): if isinstance(known_file, CivitFile):
model_info_res = session.get( model_info_res = session.get(
f"https://civitai.com/api/v1/models/{known_file.model_id}?modelVersionId={known_file.model_version_id}") f"https://civitai.com/api/v1/models/{known_file.model_id}?modelVersionId={known_file.model_version_id}")
model_info: CivitModelsGetResponse = model_info_res.json() model_info: CivitModelsGetResponse = model_info_res.json()
for model_version in model_info['modelVersions']:
for file in model_version['files']: file: CivitFile_
if file['name'] == filename: for file in chain.from_iterable(version['files'] for version in model_info['modelVersions']):
url = file['downloadUrl'] if file['name'] == filename:
break url = file['downloadUrl']
if url is not None:
break break
else: else:
raise RuntimeError("unknown file type") raise RuntimeError("unknown file type")
@ -53,19 +64,29 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
if url is None: if url is None:
logging.warning(f"Could not retrieve file {str(known_file)}") logging.warning(f"Could not retrieve file {str(known_file)}")
else: else:
with session.get(url, stream=True, allow_redirects=True) as response: destination_with_filename = join(destination, save_filename)
total_size = int(response.headers.get("content-length", 0)) try:
progress_bar = ProgressBar(total=total_size)
with open(join(destination, filename), "wb") as file: with session.get(url, stream=True, allow_redirects=True) as response:
for chunk in response.iter_content(chunk_size=512 * 1024): total_size = int(response.headers.get("content-length", 0))
progress_bar.update(len(chunk)) progress_bar = ProgressBar(total=total_size)
file.write(chunk) with open(destination_with_filename, "wb") as file:
for chunk in response.iter_content(chunk_size=512 * 1024):
progress_bar.update(len(chunk))
file.write(chunk)
except InterruptProcessingException:
os.remove(destination_with_filename)
path = folder_paths.get_full_path(folder_name, filename) path = folder_paths.get_full_path(folder_name, filename)
assert path is not None assert path is not None
except StopIteration: except StopIteration:
pass pass
except Exception as exc: except Exception as exc:
logging.error("Error while trying to download a file", exc_info=exc) logging.error("Error while trying to download a file", exc_info=exc)
finally:
# a path was found for any reason, so we should invalidate the cache
if path is not None:
folder_paths.invalidate_cache(folder_name)
return path return path
@ -118,5 +139,10 @@ KNOWN_CLIP_VISION_MODELS = [
KNOWN_LORAS = [ KNOWN_LORAS = [
CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"), CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
# todo: a lot of the slider loras are useful and should also be included ]
]
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
symbol += models
folder_paths.invalidate_cache(folder_name)
return symbol

View File

@ -22,6 +22,14 @@ class CivitFile:
def __str__(self): def __str__(self):
return self.filename return self.filename
@property
def save_with_filename(self):
return self.filename
@property
def alternate_filenames(self):
return []
@dataclasses.dataclass @dataclasses.dataclass
class HuggingFile: class HuggingFile:
@ -35,6 +43,8 @@ class HuggingFile:
""" """
repo_id: str repo_id: str
filename: str filename: str
save_with_filename: Optional[str] = None
alternate_filenames: List[str] = dataclasses.field(default_factory=list)
show_in_ui: Optional[bool] = True show_in_ui: Optional[bool] = True
def __str__(self): def __str__(self):

View File

@ -2,7 +2,7 @@ import psutil
import logging import logging
from enum import Enum from enum import Enum
from .cli_args import args from .cli_args import args
from . import utils from . import interruption
from threading import RLock from threading import RLock
import torch import torch
@ -840,31 +840,14 @@ def unload_all_models():
def resolve_lowvram_weight(weight, model, key): #TODO: remove def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight return weight
#TODO: might be cleaner to put this somewhere else
import threading
class InterruptProcessingException(Exception):
pass
interrupt_processing_mutex = threading.RLock()
interrupt_processing = False
def interrupt_current_processing(value=True): def interrupt_current_processing(value=True):
global interrupt_processing interruption.interrupt_current_processing(value)
global interrupt_processing_mutex
with interrupt_processing_mutex:
interrupt_processing = value
def processing_interrupted(): def processing_interrupted():
global interrupt_processing interruption.processing_interrupted()
global interrupt_processing_mutex
with interrupt_processing_mutex:
return interrupt_processing
def throw_exception_if_processing_interrupted(): def throw_exception_if_processing_interrupted():
global interrupt_processing interruption.throw_exception_if_processing_interrupted()
global interrupt_processing_mutex
with interrupt_processing_mutex:
if interrupt_processing:
interrupt_processing = False
raise InterruptProcessingException()

View File

@ -25,7 +25,6 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview from ..cmd import folder_paths, latent_preview
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS
from ..model_downloader_types import HuggingFile
from ..nodes.common import MAX_RESOLUTION from ..nodes.common import MAX_RESOLUTION
from .. import controlnet from .. import controlnet

View File

@ -10,11 +10,8 @@ from functools import reduce
from importlib.metadata import entry_points from importlib.metadata import entry_points
from pkg_resources import resource_filename from pkg_resources import resource_filename
from comfy_extras import nodes as comfy_extras_nodes
from . import base_nodes
from .package_typing import ExportedNodes from .package_typing import ExportedNodes
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
_comfy_nodes: ExportedNodes = ExportedNodes() _comfy_nodes: ExportedNodes = ExportedNodes()
@ -77,6 +74,10 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import
def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
# now actually import the nodes, to improve control of node loading order
from comfy_extras import nodes as comfy_extras_nodes
from . import base_nodes
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
# only load these nodes once # only load these nodes once
if len(_comfy_nodes) == 0: if len(_comfy_nodes) == 0:
base_and_extra = reduce(lambda x, y: x.update(y), base_and_extra = reduce(lambda x, y: x.update(y),

View File

@ -1,16 +1,26 @@
import os.path import os.path
from contextlib import contextmanager
import torch import torch
import math import math
import struct import struct
from . import checkpoint_pickle
from tqdm import tqdm
from . import checkpoint_pickle, interruption
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm
from contextlib import contextmanager
import logging import logging
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.queue_types import BinaryEventTypes
PROGRESS_BAR_ENABLED = True
PROGRESS_BAR_HOOK = None
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
@ -457,16 +467,33 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div output[b:b+1] = out/out_div
return output return output
PROGRESS_BAR_ENABLED = True
def hijack_progress(server: ExecutorToClientProgress):
def hook(value: float, total: float, preview_image):
interruption.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
set_progress_bar_global_hook(hook)
def set_progress_bar_enabled(enabled): def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED global PROGRESS_BAR_ENABLED
PROGRESS_BAR_ENABLED = enabled PROGRESS_BAR_ENABLED = enabled
PROGRESS_BAR_HOOK = None
def get_progress_bar_enabled() -> bool:
return PROGRESS_BAR_ENABLED
def set_progress_bar_global_hook(function): def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function PROGRESS_BAR_HOOK = function
class ProgressBar: class ProgressBar:
def __init__(self, total: float): def __init__(self, total: float):
global PROGRESS_BAR_HOOK global PROGRESS_BAR_HOOK

View File

@ -31,4 +31,5 @@ aio-pika
pyjwt[crypto] pyjwt[crypto]
kornia>=0.7.1 kornia>=0.7.1
mpmath>=1.0,!=1.4.0a0 mpmath>=1.0,!=1.4.0a0
huggingface_hub huggingface_hub
lazy-object-proxy