Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improvements to node loading, node API, folder paths and progress
- Improved node loading order: nodes are now imported "as late as possible", and the configuration is exposed as described in the README.
- Added methods for custom nodes to register custom model folders, and the models used in examples, more robustly.
- Model downloads can now be gracefully interrupted.
- Progress notifications are now sent over the network for distributed ComfyUI operations.
- Moved Python objects to reduce transitive package import issues.
parent 3ccbda36da
commit 341c9f2e90
.editorconfig (new file)

@@ -0,0 +1,4 @@
+root = true
+
+[*]
+max_line_length = off
comfy/cmd/execution.py

@@ -11,7 +11,9 @@ import typing
 from typing import List, Optional, Tuple, Union
 from typing_extensions import TypedDict
 import torch
+import lazy_object_proxy

+from .. import interruption
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..component_model.executor_types import ExecutorToClientProgress
@@ -21,7 +23,7 @@ from ..nodes.package import import_all_nodes_in_workspace

 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
-nodes: ExportedNodes = import_all_nodes_in_workspace()
+nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)

 def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
     if extra_data is None:
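The lazy proxy is what defers node loading: import_all_nodes_in_workspace only runs when the proxy is first dereferenced, not when this module is imported. A minimal sketch of that behavior (the factory below is illustrative, not the commit's real one):

import lazy_object_proxy

def _expensive_factory():
    print("loading nodes...")  # runs once, on first real use
    return {"KSampler": object()}

nodes = lazy_object_proxy.Proxy(_expensive_factory)  # nothing loaded yet
print("KSampler" in nodes)  # first access triggers _expensive_factory()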
@@ -80,16 +82,16 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
     results = []
     if input_is_list:
         if allow_interrupt:
-            model_management.throw_exception_if_processing_interrupted()
+            interruption.throw_exception_if_processing_interrupted()
         results.append(getattr(obj, func)(**input_data_all))
     elif max_len_input == 0:
         if allow_interrupt:
-            model_management.throw_exception_if_processing_interrupted()
+            interruption.throw_exception_if_processing_interrupted()
         results.append(getattr(obj, func)())
     else:
         for i in range(max_len_input):
             if allow_interrupt:
-                model_management.throw_exception_if_processing_interrupted()
+                interruption.throw_exception_if_processing_interrupted()
             results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
     return results
@@ -185,7 +187,7 @@ def recursive_execute(server: ExecutorToClientProgress,
         if server.client_id is not None:
             server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
                              server.client_id)
-    except model_management.InterruptProcessingException as iex:
+    except interruption.InterruptProcessingException as iex:
         logging.info("Processing interrupted")

         # skip formatting inputs/outputs
@@ -327,7 +329,7 @@ class PromptExecutor:

         # First, send back the status to the frontend depending
         # on the exception type
-        if isinstance(ex, model_management.InterruptProcessingException):
+        if isinstance(ex, interruption.InterruptProcessingException):
             mes = {
                 "prompt_id": prompt_id,
                 "node_id": node_id,
@@ -367,7 +369,7 @@ class PromptExecutor:
         execute_outputs = []
         if extra_data is None:
             extra_data = {}
-        model_management.interrupt_current_processing(False)
+        interruption.interrupt_current_processing(False)

         if "client_id" in extra_data:
             self.server.client_id = extra_data["client_id"]
comfy/cmd/folder_paths.py

@@ -1,17 +1,79 @@
+import dataclasses
 import os
+import posixpath
 import sys
 import time
 import logging
-from typing import Optional
+from typing import Optional, List, Set, Dict, Any, Iterator, Sequence

 from pkg_resources import resource_filename

 from ..cli_args import args

-supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
+supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])

-folder_names_and_paths = {}
+
+@dataclasses.dataclass
+class FolderPathsTuple:
+    folder_name: str
+    paths: List[str] = dataclasses.field(default_factory=list)
+    supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
+
+    def __getitem__(self, item: Any):
+        if item == 0:
+            return self.paths
+        if item == 1:
+            return self.supported_extensions
+        else:
+            raise RuntimeError("unsupported tuple index")
+
+    def __add__(self, other: "FolderPathsTuple"):
+        assert self.folder_name == other.folder_name
+        new_paths = list(frozenset(self.paths + other.paths))
+        new_supported_extensions = self.supported_extensions | other.supported_extensions
+        return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
+
+    def __iter__(self) -> Iterator[Sequence[str]]:
+        yield self.paths
+        yield self.supported_extensions
+
+
+class FolderNames:
+    def __init__(self):
+        self.contents: Dict[str, FolderPathsTuple] = dict()
+
+    def __getitem__(self, item) -> FolderPathsTuple:
+        if not isinstance(item, str):
+            raise RuntimeError("expected folder path")
+        if item not in self.contents:
+            self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set())
+        return self.contents[item]
+
+    def __setitem__(self, key: str, value: FolderPathsTuple):
+        assert isinstance(key, str)
+        if isinstance(value, tuple):
+            paths, supported_extensions = value
+            value = FolderPathsTuple(key, paths, supported_extensions)
+        if key in self.contents:
+            value = self.contents[key] + value
+        self.contents[key] = value
+
+    def __len__(self):
+        return len(self.contents)
+
+    def items(self):
+        return self.contents.items()
+
+    def values(self):
+        return self.contents.values()
+
+    def keys(self):
+        return self.contents.keys()
+
+
+folder_names_and_paths = FolderNames()

 # todo: this should be initialized elsewhere
 if 'main.py' in sys.argv:
     base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
 elif args.cwd is not None:
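FolderPathsTuple and FolderNames keep the long-standing (paths, extensions) tuple contract working while callers migrate to attributes, and assignments merge rather than clobber. A sketch of the resulting behavior, assuming the module above is imported as folder_paths:

from comfy.cmd import folder_paths

entry = folder_paths.folder_names_and_paths["loras"]
paths, extensions = entry        # __iter__ keeps tuple unpacking working
assert entry[0] is entry.paths   # __getitem__ keeps integer indexing working

# Plain tuples are coerced by __setitem__ and merged via __add__, so a second
# registration accumulates paths and extensions instead of replacing them.
folder_paths.folder_names_and_paths["loras"] = (["/data/extra_loras"], {".safetensors"})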
@@ -25,31 +87,31 @@ elif args.cwd is not None:
 else:
     base_path = os.getcwd()
 models_dir = os.path.join(base_path, "models")
-folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
-folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], set([".yaml"]))
-folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
-folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
-folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
-folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
-folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
-folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
-folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
-folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
-folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
-folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
-folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
-folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
-folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
-folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
-folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
-folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
+folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
+folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
+folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
+folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
+folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
+folder_names_and_paths["unet"] = FolderPathsTuple("unet", [os.path.join(models_dir, "unet")], set(supported_pt_extensions))
+folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions))
+folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions))
+folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions))
+folder_names_and_paths["diffusers"] = FolderPathsTuple("diffusers", [os.path.join(models_dir, "diffusers")], {"folder"})
+folder_names_and_paths["vae_approx"] = FolderPathsTuple("vae_approx", [os.path.join(models_dir, "vae_approx")], set(supported_pt_extensions))
+folder_names_and_paths["controlnet"] = FolderPathsTuple("controlnet", [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], set(supported_pt_extensions))
+folder_names_and_paths["gligen"] = FolderPathsTuple("gligen", [os.path.join(models_dir, "gligen")], set(supported_pt_extensions))
+folder_names_and_paths["upscale_models"] = FolderPathsTuple("upscale_models", [os.path.join(models_dir, "upscale_models")], set(supported_pt_extensions))
+folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.path.join(base_path, "custom_nodes")], set())
+folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
+folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
+folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})

 output_directory = os.path.join(base_path, "output")
 temp_directory = os.path.join(base_path, "temp")
 input_directory = os.path.join(base_path, "input")
 user_directory = os.path.join(base_path, "user")

-filename_list_cache = {}
+_filename_list_cache = {}

 if not os.path.exists(input_directory):
     try:
@@ -57,32 +119,38 @@ if not os.path.exists(input_directory):
     except:
         logging.error("Failed to create input directory")


 def set_output_directory(output_dir):
     global output_directory
     output_directory = output_dir


 def set_temp_directory(temp_dir):
     global temp_directory
     temp_directory = temp_dir


 def set_input_directory(input_dir):
     global input_directory
     input_directory = input_dir


 def get_output_directory():
     global output_directory
     return output_directory


 def get_temp_directory():
     global temp_directory
     return temp_directory


 def get_input_directory():
     global input_directory
     return input_directory


-#NOTE: used in http server so don't put folders that should not be accessed remotely
+# NOTE: used in http server so don't put folders that should not be accessed remotely
 def get_directory_by_type(type_name):
     if type_name == "output":
         return get_output_directory()
@@ -133,17 +201,29 @@ def exists_annotated_filepath(name):
     return os.path.exists(filepath)


-def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None):
+def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -> str:
+    """
+    Registers a model path for the given canonical name.
+    :param folder_name: the folder name
+    :param full_folder_path: When none, defaults to os.path.join(models_dir, folder_name) aka the folder as
+    a subpath to the default models directory
+    :return: the folder path
+    """
     global folder_names_and_paths
     if full_folder_path is None:
         full_folder_path = os.path.join(models_dir, folder_name)
-    if folder_name in folder_names_and_paths and full_folder_path not in folder_names_and_paths[folder_name][0]:
-        folder_names_and_paths[folder_name][0].append(full_folder_path)
-    else:
-        folder_names_and_paths[folder_name] = ([full_folder_path], set())
-
-def get_folder_paths(folder_name):
-    return folder_names_and_paths[folder_name][0][:]
+    folder_path = folder_names_and_paths[folder_name]
+    if full_folder_path not in folder_path.paths:
+        folder_path.paths.append(full_folder_path)
+
+    invalidate_cache(folder_name)
+    return full_folder_path
+
+
+def get_folder_paths(folder_name) -> List[str]:
+    return folder_names_and_paths[folder_name].paths[:]


 def recursive_search(directory, excluded_dir_names=None):
     if not os.path.isdir(directory):
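A hypothetical custom node registering its own model folder with the new API; the "wildcards" name and paths are illustrative:

from comfy.cmd import folder_paths

# Defaults to <models_dir>/wildcards and returns the resolved path.
wildcards_dir = folder_paths.add_model_folder_path("wildcards")

# An additional explicit location accumulates onto the same canonical name.
folder_paths.add_model_folder_path("wildcards", "/mnt/shared/wildcards")
assert wildcards_dir in folder_paths.get_folder_paths("wildcards")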
@@ -176,24 +256,41 @@ def recursive_search(directory, excluded_dir_names=None):
             continue
     return result, dirs


 def filter_files_extensions(files, extensions):
     return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))


 def get_full_path(folder_name, filename):
+    """
+    Gets the path to a filename inside a folder.
+
+    Works with untrusted filenames.
+    :param folder_name:
+    :param filename:
+    :return:
+    """
     global folder_names_and_paths
     if folder_name not in folder_names_and_paths:
         return None
-    folders = folder_names_and_paths[folder_name]
-    filename = os.path.relpath(os.path.join("/", filename), "/")
-    for x in folders[0]:
-        full_path = os.path.join(x, filename)
-        if os.path.isfile(full_path):
-            return full_path
+    folders = folder_names_and_paths[folder_name].paths
+    filename_split = os.path.split(filename)
+
+    trusted_paths = []
+    for folder in folders:
+        folder_split = os.path.split(folder)
+        abs_file_path = os.path.abspath(os.path.join(*folder_split, *filename_split))
+        abs_folder_path = os.path.abspath(folder)
+        if os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path:
+            trusted_paths.append(abs_file_path)
+        else:
+            logging.error(f"attempted to access untrusted path {abs_file_path} in {folder_name} for filename {filename}")
+
+    for trusted_path in trusted_paths:
+        if os.path.isfile(trusted_path):
+            return trusted_path
+
+    return None


 def get_filename_list_(folder_name):
     global folder_names_and_paths
     output_list = set()
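The trusted-path check above is the core of the untrusted-filename handling: a candidate is accepted only when it resolves inside a registered folder. A standalone sketch of the same test:

import os

def is_trusted(folder: str, filename: str) -> bool:
    abs_file_path = os.path.abspath(os.path.join(folder, filename))
    abs_folder_path = os.path.abspath(folder)
    return os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path

assert is_trusted("/models/loras", "style.safetensors")
assert not is_trusted("/models/loras", "../../etc/passwd")  # traversal is rejected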
@@ -204,14 +301,15 @@ def get_filename_list_(folder_name):
         output_list.update(filter_files_extensions(files, folders[1]))
         output_folders = {**output_folders, **folders_all}

-    return (sorted(list(output_list)), output_folders, time.perf_counter())
+    return sorted(list(output_list)), output_folders, time.perf_counter()


 def cached_filename_list_(folder_name):
-    global filename_list_cache
+    global _filename_list_cache
     global folder_names_and_paths
-    if folder_name not in filename_list_cache:
+    if folder_name not in _filename_list_cache:
         return None
-    out = filename_list_cache[folder_name]
+    out = _filename_list_cache[folder_name]

     for x in out[1]:
         time_modified = out[1][x]
@@ -227,14 +325,16 @@ def cached_filename_list_(folder_name):

     return out


 def get_filename_list(folder_name):
     out = cached_filename_list_(folder_name)
     if out is None:
         out = get_filename_list_(folder_name)
-        global filename_list_cache
-        filename_list_cache[folder_name] = out
+        global _filename_list_cache
+        _filename_list_cache[folder_name] = out
     return list(out[0])


 def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
     def map_filename(filename):
         prefix_len = len(os.path.basename(filename_prefix))
@@ -255,7 +355,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
     subfolder = os.path.dirname(os.path.normpath(filename_prefix))
     filename = os.path.basename(os.path.normpath(filename_prefix))

-    full_output_folder = os.path.join(output_dir, subfolder)
+    full_output_folder = str(os.path.join(output_dir, subfolder))

     if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
         err = "**** ERROR: Saving image outside the output folder is not allowed." + \
@@ -276,8 +376,14 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height


 def create_directories():
-    for _, (paths, _) in folder_names_and_paths.items():
-        default_path = paths[0]
-        os.makedirs(default_path, exist_ok=True)
+    # all configured paths should be created
+    for folder_path_spec in folder_names_and_paths.values():
+        for path in folder_path_spec.paths:
+            os.makedirs(path, exist_ok=True)
     for path in (temp_directory, input_directory, output_directory, user_directory):
         os.makedirs(path, exist_ok=True)


+def invalidate_cache(folder_name):
+    global _filename_list_cache
+    _filename_list_cache.pop(folder_name, None)
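Filename listings are cached per folder name, so anything that changes a folder's paths has to call invalidate_cache, which add_model_folder_path now does. An illustrative flow (the extra path is made up):

from comfy.cmd import folder_paths

before = folder_paths.get_filename_list("loras")                # scans, then caches
folder_paths.add_model_folder_path("loras", "/mnt/more_loras")  # appends and invalidates
after = folder_paths.get_filename_list("loras")                 # re-scans both locations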
comfy/cmd/main.py

@@ -1,94 +1,26 @@
-from .. import options
-
-# Suppress warnings during import
-import warnings
-
-warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
-options.enable_args_parsing()
-
 import asyncio
 import gc
 import itertools
 import logging
 import os
-import importlib.util
-
-from ..cmd import cuda_malloc
-from ..cmd import folder_paths
-from .extra_model_paths import load_extra_path_config
-from ..analytics.analytics import initialize_event_tracking
-from ..nodes.package import import_all_nodes_in_workspace
-
-import time
-
-
-def execute_prestartup_script():
-    def execute_script(script_path):
-        module_name = os.path.splitext(script_path)[0]
-        try:
-            spec = importlib.util.spec_from_file_location(module_name, script_path)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
-            return True
-        except Exception as e:
-            logging.error(f"Failed to execute startup-script: {script_path} / {e}")
-        return False
-
-    node_paths = folder_paths.get_folder_paths("custom_nodes")
-    node_prestartup_times = []
-    for custom_node_path in node_paths:
-        possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else []
-
-        for possible_module in possible_modules:
-            module_path = os.path.join(custom_node_path, possible_module)
-            if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
-                continue
-
-            script_path = os.path.join(module_path, "prestartup_script.py")
-            if os.path.exists(script_path):
-                time_before = time.perf_counter()
-                success = execute_script(script_path)
-                node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
-    if len(node_prestartup_times) > 0:
-        logging.info("\nPrestartup times for custom nodes:")
-        for n in sorted(node_prestartup_times):
-            if n[2]:
-                import_message = ""
-            else:
-                import_message = " (PRESTARTUP FAILED)"
-            logging.info("{:6.1f} seconds{}:".format(n[0], import_message), n[1])


-execute_prestartup_script()
-
-# Main code
-import asyncio
-import itertools
 import shutil
 import threading
-import gc
-
-from ..cli_args import args
-
-if os.name == "nt":
-    import logging
-
-    logging.getLogger("xformers").addFilter(
-        lambda record: 'A matching Triton is not available' not in record.getMessage())
-
-if args.cuda_device is not None:
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-    logging.info("Set cuda device to:", args.cuda_device)
-
-if args.deterministic:
-    if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
-        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
-
-from .. import utils
 import time

+from comfy.utils import hijack_progress
+from .extra_model_paths import load_extra_path_config
+from .main_pre import args
+from .. import model_management
+from ..analytics.analytics import initialize_event_tracking
+from ..cmd import cuda_malloc
+from ..cmd import folder_paths
 from ..cmd import server as server_module
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus
-from .. import model_management
+from ..component_model.queue_types import ExecutionStatus
 from ..distributed.distributed_prompt_queue import DistributedPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..distributed.server_stub import ServerStub
+from ..nodes.package import import_all_nodes_in_workspace


 def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
||||
@ -120,7 +52,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
|
||||
completed=e.success,
|
||||
messages=e.status_messages))
|
||||
if _server.client_id is not None:
|
||||
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
|
||||
_server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
@@ -152,18 +84,6 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())


-def hijack_progress(server: ExecutorToClientProgress):
-    def hook(value: float, total: float, preview_image):
-        model_management.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
-
-        server.send_sync("progress", progress, server.client_id)
-        if preview_image is not None:
-            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
-
-    utils.set_progress_bar_global_hook(hook)
-
-
 def cleanup_temp():
     try:
         temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
@@ -194,6 +114,18 @@ async def main():
     folder_paths.set_temp_directory(temp_dir)
     cleanup_temp()

+    # configure extra model paths earlier
+    try:
+        extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
+        if os.path.isfile(extra_model_paths_config_path):
+            load_extra_path_config(extra_model_paths_config_path)
+    except NameError:
+        pass
+
+    if args.extra_model_paths_config:
+        for config_path in itertools.chain(*args.extra_model_paths_config):
+            load_extra_path_config(config_path)
+
     # create the default directories if we're instructed to, then exit
     # or, if it's a windows standalone build, the single .exe file should have its side-by-side directories always created
     if args.create_directories:
@@ -208,18 +140,6 @@ async def main():
         except:
             pass

-    # configure extra model paths earlier
-    try:
-        extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
-        if os.path.isfile(extra_model_paths_config_path):
-            load_extra_path_config(extra_model_paths_config_path)
-    except NameError:
-        pass
-
-    if args.extra_model_paths_config:
-        for config_path in itertools.chain(*args.extra_model_paths_config):
-            load_extra_path_config(config_path)
-
     loop = asyncio.get_event_loop()
     server = server_module.PromptServer(loop)
     if args.external_address is not None:
@@ -246,7 +166,6 @@ async def main():
         q = PromptQueue(server)
         server.prompt_queue = q

-
     server.add_routes()
     hijack_progress(server)
     cuda_malloc_warning()
@@ -293,7 +212,7 @@ async def main():
     server.port = args.port
     try:
         await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
-                   call_on_start=call_on_start)
+                  call_on_start=call_on_start)
    except asyncio.CancelledError:
        if distributed:
            await q.close()
comfy/cmd/main_pre.py (new file)

@@ -0,0 +1,23 @@
+import os
+
+from .. import options
+
+import warnings
+import logging
+
+options.enable_args_parsing()
+if os.name == "nt":
+    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
+
+from ..cli_args import args
+
+if args.cuda_device is not None:
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+    logging.info("Set cuda device to:", args.cuda_device)
+
+if args.deterministic:
+    if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
+__all__ = ["args"]
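main_pre.py centralizes the side effects that must run before anything imports torch; the ordering matters because, for example, CUDA_VISIBLE_DEVICES is only honored if it is set before CUDA initializes. A sketch of that constraint (the device id is illustrative):

import os

os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")  # must precede the torch import

import torch  # noqa: E402 -- deliberately after the environment setup

print(torch.cuda.device_count())  # only the devices exposed above are visible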
comfy/cmd/server.py

@@ -23,6 +23,7 @@ import aiofiles
 import aiohttp
 from aiohttp import web

+import comfy.interruption
 from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
 from ..cmd import execution
 from ..cmd import folder_paths

@@ -515,7 +516,7 @@ class PromptServer(ExecutorToClientProgress):

         @routes.post("/interrupt")
         async def post_interrupt(request):
-            model_management.interrupt_current_processing()
+            comfy.interruption.interrupt_current_processing()
             return web.Response(status=200)

         @routes.post("/free")
@@ -10,6 +10,7 @@ from aio_pika.patterns import RPC

 from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
 from ..component_model.queue_types import BinaryEventTypes
+from ..utils import hijack_progress


 async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,

@@ -24,7 +25,7 @@ def _get_name(queue_name: str, user_id: str) -> str:


 class DistributedExecutorToClientProgress(ExecutorToClientProgress):
-    def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
+    def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
         self._rpc = rpc
         self._queue_name = queue_name
         self._loop = loop

@@ -32,6 +33,8 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
         self.client_id = None
         self.node_id = None
         self.last_node_id = None
+        if receive_all_progress_notifications:
+            hijack_progress(self)

     async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
         # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
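Calling hijack_progress(self) installs the distributed client as the global progress hook, which is how progress notifications reach the network. A sketch of the flow with a stand-in receiver instead of the real RPC wiring (assumes ProgressBar.update behaves as in upstream ComfyUI):

from comfy import utils

class StubReceiver:
    client_id = "worker-1"
    last_prompt_id = None
    last_node_id = None

    def send_sync(self, event, data, sid=None):
        print(event, data)  # a real receiver forwards this over RPC

utils.hijack_progress(StubReceiver())  # every ProgressBar now reports here
utils.ProgressBar(total=10).update(1)  # emits a "progress" event to the stub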
comfy/interruption.py (new file)

@@ -0,0 +1,31 @@
+import threading
+
+_interrupt_processing_mutex = threading.RLock()
+_interrupt_processing = False
+
+
+class InterruptProcessingException(Exception):
+    pass
+
+
+def interrupt_current_processing(value=True):
+    global _interrupt_processing
+    global _interrupt_processing_mutex
+    with _interrupt_processing_mutex:
+        _interrupt_processing = value
+
+
+def processing_interrupted():
+    global _interrupt_processing
+    global _interrupt_processing_mutex
+    with _interrupt_processing_mutex:
+        return _interrupt_processing
+
+
+def throw_exception_if_processing_interrupted():
+    global _interrupt_processing
+    global _interrupt_processing_mutex
+    with _interrupt_processing_mutex:
+        if _interrupt_processing:
+            _interrupt_processing = False
+            raise InterruptProcessingException()
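The new module makes cancellation cooperative: work checks the flag at safe points, and the flag resets itself when the exception is raised. An illustrative worker loop:

import threading
import time

from comfy import interruption

def worker():
    try:
        for _ in range(100):
            interruption.throw_exception_if_processing_interrupted()
            time.sleep(0.01)  # stand-in for one unit of real work
    except interruption.InterruptProcessingException:
        print("stopped cleanly at a step boundary")

t = threading.Thread(target=worker)
t.start()
interruption.interrupt_current_processing()  # e.g. from the /interrupt route
t.join()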
comfy/model_downloader.py

@@ -1,15 +1,18 @@
 from __future__ import annotations

 import logging
+import os
+from itertools import chain
 from os.path import join
-from typing import List, Any, Optional
+from typing import List, Any, Optional, Union

 from huggingface_hub import hf_hub_download
 from requests import Session

 from .cmd import folder_paths
-from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse
-from .utils import comfy_tqdm, ProgressBar
+from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
+from .interruption import InterruptProcessingException
+from .utils import ProgressBar, comfy_tqdm

 session = Session()
@@ -20,32 +23,40 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]
     return sorted(list(existing | downloadable))


-def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> str:
+def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> Optional[str]:
     path = folder_paths.get_full_path(folder_name, filename)

     if path is None:
         try:
             # todo: should this be the first or last path?
             destination = folder_paths.get_folder_paths(folder_name)[0]
-            known_file = next(f for f in known_files if str(f) == filename)
+            known_file: Optional[HuggingFile | CivitFile] = None
+            for candidate in known_files:
+                if candidate.filename == filename or filename in candidate.alternate_filenames or filename == candidate.save_with_filename:
+                    known_file = candidate
+                    break
+            if known_file is None:
+                return path
             with comfy_tqdm():
                 if isinstance(known_file, HuggingFile):
+                    save_filename = known_file.save_with_filename or known_file.filename
                     path = hf_hub_download(repo_id=known_file.repo_id,
-                                           filename=known_file.filename,
+                                           filename=save_filename,
                                            local_dir=destination,
                                            resume_download=True)
                 else:
                     url: Optional[str] = None
+                    save_filename = known_file.save_with_filename or known_file.filename
+
                     if isinstance(known_file, CivitFile):
                         model_info_res = session.get(
                             f"https://civitai.com/api/v1/models/{known_file.model_id}?modelVersionId={known_file.model_version_id}")
                         model_info: CivitModelsGetResponse = model_info_res.json()
-                        for model_version in model_info['modelVersions']:
-                            for file in model_version['files']:
-                                if file['name'] == filename:
-                                    url = file['downloadUrl']
-                                    break
-                            if url is not None:
+
+                        file: CivitFile_
+                        for file in chain.from_iterable(version['files'] for version in model_info['modelVersions']):
+                            if file['name'] == filename:
+                                url = file['downloadUrl']
                                 break
                     else:
                         raise RuntimeError("unknown file type")

@@ -53,19 +64,29 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
                     if url is None:
                         logging.warning(f"Could not retrieve file {str(known_file)}")
                     else:
-                        with session.get(url, stream=True, allow_redirects=True) as response:
-                            total_size = int(response.headers.get("content-length", 0))
-                            progress_bar = ProgressBar(total=total_size)
-                            with open(join(destination, filename), "wb") as file:
-                                for chunk in response.iter_content(chunk_size=512 * 1024):
-                                    progress_bar.update(len(chunk))
-                                    file.write(chunk)
+                        destination_with_filename = join(destination, save_filename)
+                        try:
+                            with session.get(url, stream=True, allow_redirects=True) as response:
+                                total_size = int(response.headers.get("content-length", 0))
+                                progress_bar = ProgressBar(total=total_size)
+                                with open(destination_with_filename, "wb") as file:
+                                    for chunk in response.iter_content(chunk_size=512 * 1024):
+                                        progress_bar.update(len(chunk))
+                                        file.write(chunk)
+                        except InterruptProcessingException:
+                            os.remove(destination_with_filename)

             path = folder_paths.get_full_path(folder_name, filename)
             assert path is not None
         except StopIteration:
             pass
+        except Exception as exc:
+            logging.error("Error while trying to download a file", exc_info=exc)
+        finally:
+            # a path was found for any reason, so we should invalidate the cache
+            if path is not None:
+                folder_paths.invalidate_cache(folder_name)
     return path
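Downloads are now interruptible end to end: the ProgressBar hook raises InterruptProcessingException mid-stream, the partial file is removed, and get_or_download returns None. An illustrative call (the checkpoint name is an example, not necessarily present in KNOWN_CHECKPOINTS):

from comfy.model_downloader import KNOWN_CHECKPOINTS, get_or_download

path = get_or_download("checkpoints", "v1-5-pruned-emaonly.safetensors", KNOWN_CHECKPOINTS)
if path is None:
    print("not available locally and the download did not complete")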
@@ -118,5 +139,10 @@ KNOWN_CLIP_VISION_MODELS = [

 KNOWN_LORAS = [
     CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
+    # todo: a lot of the slider loras are useful and should also be included
 ]
+
+
+def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
+    symbol += models
+    folder_paths.invalidate_cache(folder_name)
+    return symbol
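add_known_models is the registration hook for custom nodes to ship their own downloadable models; a hypothetical use (the repo and filename are made up):

from comfy.model_downloader import KNOWN_CHECKPOINTS, add_known_models
from comfy.model_downloader_types import HuggingFile

add_known_models("checkpoints", KNOWN_CHECKPOINTS,
                 HuggingFile("my-org/my-model", "my_model.safetensors"))
# The file now appears in checkpoint listings and is fetched on first use.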
comfy/model_downloader_types.py

@@ -22,6 +22,14 @@ class CivitFile:
     def __str__(self):
         return self.filename

+    @property
+    def save_with_filename(self):
+        return self.filename
+
+    @property
+    def alternate_filenames(self):
+        return []
+

 @dataclasses.dataclass
 class HuggingFile:

@@ -35,6 +43,8 @@ class HuggingFile:
     """
     repo_id: str
     filename: str
+    save_with_filename: Optional[str] = None
+    alternate_filenames: List[str] = dataclasses.field(default_factory=list)
     show_in_ui: Optional[bool] = True

     def __str__(self):
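The two new HuggingFile fields mirror the CivitFile properties above, letting a file be saved under a different name while older names still resolve. An illustrative instance (the field values are made up):

from comfy.model_downloader_types import HuggingFile

f = HuggingFile("some-org/some-vae", "vae-840000-ema-pruned.safetensors",
                save_with_filename="sd15_vae.safetensors",
                alternate_filenames=["vae-840000-ema-pruned.ckpt"])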
comfy/model_management.py

@@ -2,7 +2,7 @@ import psutil
 import logging
 from enum import Enum
 from .cli_args import args
-from . import utils
+from . import interruption
 from threading import RLock

 import torch

@@ -840,31 +840,14 @@ def unload_all_models():
 def resolve_lowvram_weight(weight, model, key): #TODO: remove
     return weight

-#TODO: might be cleaner to put this somewhere else
-import threading
-
-class InterruptProcessingException(Exception):
-    pass
-
-interrupt_processing_mutex = threading.RLock()
-
-interrupt_processing = False
+
 def interrupt_current_processing(value=True):
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        interrupt_processing = value
+    interruption.interrupt_current_processing(value)


 def processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        return interrupt_processing
+    interruption.processing_interrupted()


 def throw_exception_if_processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        if interrupt_processing:
-            interrupt_processing = False
-            raise InterruptProcessingException()
+    interruption.throw_exception_if_processing_interrupted()
@@ -25,7 +25,6 @@ from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
     KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS
-from ..model_downloader_types import HuggingFile
 from ..nodes.common import MAX_RESOLUTION
 from .. import controlnet
comfy/nodes/package.py

@@ -10,11 +10,8 @@ from functools import reduce
 from importlib.metadata import entry_points

 from pkg_resources import resource_filename

-from comfy_extras import nodes as comfy_extras_nodes
-from . import base_nodes
 from .package_typing import ExportedNodes
-from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes


 _comfy_nodes: ExportedNodes = ExportedNodes()

@@ -77,6 +74,10 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import


 def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
+    # now actually import the nodes, to improve control of node loading order
+    from comfy_extras import nodes as comfy_extras_nodes
+    from . import base_nodes
+    from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
     # only load these nodes once
     if len(_comfy_nodes) == 0:
         base_and_extra = reduce(lambda x, y: x.update(y),
comfy/utils.py

@@ -1,16 +1,26 @@
 import os.path
+from contextlib import contextmanager

 import torch
 import math
 import struct
-from . import checkpoint_pickle

+from tqdm import tqdm
+
+from . import checkpoint_pickle, interruption
 import safetensors.torch
 import numpy as np
 from PIL import Image
-from tqdm import tqdm
-from contextlib import contextmanager
 import logging

+from .component_model.executor_types import ExecutorToClientProgress
+from .component_model.queue_types import BinaryEventTypes
+
+PROGRESS_BAR_ENABLED = True
+PROGRESS_BAR_HOOK = None
+

 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
         device = torch.device("cpu")

@@ -457,16 +467,33 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
             output[b:b+1] = out/out_div
     return output

-PROGRESS_BAR_ENABLED = True
+
+def hijack_progress(server: ExecutorToClientProgress):
+    def hook(value: float, total: float, preview_image):
+        interruption.throw_exception_if_processing_interrupted()
+        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+
+        server.send_sync("progress", progress, server.client_id)
+        if preview_image is not None:
+            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
+
+    set_progress_bar_global_hook(hook)
+
+
 def set_progress_bar_enabled(enabled):
     global PROGRESS_BAR_ENABLED
     PROGRESS_BAR_ENABLED = enabled

-PROGRESS_BAR_HOOK = None
+
+def get_progress_bar_enabled() -> bool:
+    return PROGRESS_BAR_ENABLED
+

 def set_progress_bar_global_hook(function):
     global PROGRESS_BAR_HOOK
     PROGRESS_BAR_HOOK = function


 class ProgressBar:
     def __init__(self, total: float):
         global PROGRESS_BAR_HOOK
requirements.txt

@@ -31,4 +31,5 @@ aio-pika
 pyjwt[crypto]
 kornia>=0.7.1
 mpmath>=1.0,!=1.4.0a0
-huggingface_hub
+huggingface_hub
+lazy-object-proxy