Merge branch 'comfyanonymous:master' into feature/blockweights

commit f2c0af4c90
ltdrdata, 2023-04-16 19:32:28 +09:00, committed by GitHub
No known key found for this signature in database; GPG Key ID: 4AEE18F83AFDEB23

25 changed files with 1991 additions and 129 deletions


@@ -32,14 +32,28 @@ This ui will let you design and execute advanced stable diffusion pipelines using
 Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)

 ## Shortcuts

-- **Ctrl + A** select all nodes
-- **Ctrl + M** mute/unmute selected nodes
-- **Delete** or **Backspace** delete selected nodes
-- **Space** Holding space key while moving the cursor moves the canvas around. It works when holding the mouse button down so it is easier to connect different nodes when the canvas gets too large.
-- **Ctrl/Shift + Click** Add clicked node to selection.
-- **Ctrl + C/Ctrl + V** - Copy and paste selected nodes, without maintaining the connection to the outputs of unselected nodes.
-- **Ctrl + C/Ctrl + Shift + V** - Copy and paste selected nodes, and maintaining the connection from the outputs of unselected nodes to the inputs of the newly pasted nodes.
-- Holding **Shift** and drag selected nodes - Move multiple selected nodes at the same time.
+| Keybind | Explanation |
+| - | - |
+| Ctrl + Enter | Queue up current graph for generation |
+| Ctrl + Shift + Enter | Queue up current graph as first for generation |
+| Ctrl + S | Save workflow |
+| Ctrl + O | Load workflow |
+| Ctrl + A | Select all nodes |
+| Ctrl + M | Mute/unmute selected nodes |
+| Delete/Backspace | Delete selected nodes |
+| Ctrl + Delete/Backspace | Delete the current graph |
+| Space | Move the canvas around when held and moving the cursor |
+| Ctrl/Shift + Click | Add clicked node to selection |
+| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| Shift + Drag | Move multiple selected nodes at the same time |
+| Ctrl + D | Load default graph |
+| Q | Toggle visibility of the queue |
+| H | Toggle visibility of history |
+| R | Refresh graph |
+
+Ctrl can also be replaced with Cmd instead for MacOS users

 # Installing


@@ -9,7 +9,7 @@ from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention

-import model_management
+from comfy import model_management

 from . import tomesd


@@ -7,7 +7,7 @@ from einops import rearrange
 from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention

-import model_management
+from comfy import model_management

 if model_management.xformers_enabled_vae():
     import xformers


@@ -24,7 +24,7 @@ except ImportError:
 from torch import Tensor
 from typing import List

-import model_management
+from comfy import model_management

 def dynamic_slice(
     x: Tensor,


@@ -307,6 +307,15 @@ def should_use_fp16():

     return True

+def soft_empty_cache():
+    global xpu_available
+    if xpu_available:
+        torch.xpu.empty_cache()
+    elif torch.cuda.is_available():
+        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
 #TODO: might be cleaner to put this somewhere else
 import threading
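The new `soft_empty_cache()` helper centralizes device cache clearing: it takes the XPU branch when Intel's extension set `xpu_available`, and only calls the CUDA cache functions when `torch.version.cuda` is set, since `ipc_collect()` reportedly makes things worse on ROCm. A minimal sketch of how a caller might use it, assuming a ComfyUI checkout on the import path (the wrapper function is illustrative):

```python
import gc

from comfy import model_management

def cleanup_between_jobs():
    # Drop Python-side references first so the device allocator
    # actually has memory to release.
    gc.collect()
    # Backend-agnostic cache clear; effectively a no-op on CPU-only machines.
    model_management.soft_empty_cache()
```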


@@ -3,7 +3,7 @@ from .k_diffusion import external as k_diffusion_external
 from .extra_samplers import uni_pc
 import torch
 import contextlib
-import model_management
+from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps


@@ -4,7 +4,7 @@ import copy

 import sd1_clip
 import sd2_clip
-import model_management
+from comfy import model_management
 from .ldm.util import instantiate_from_config
 from .ldm.models.autoencoder import AutoencoderKL
 import yaml
@@ -421,10 +421,12 @@ class CLIP:
     def clip_layer(self, layer_idx):
         self.layer_idx = layer_idx

-    def encode(self, text):
+    def tokenize(self, text, return_word_ids=False):
+        return self.tokenizer.tokenize_with_weights(text, return_word_ids)
+
+    def encode_from_tokens(self, tokens):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
-        tokens = self.tokenizer.tokenize_with_weights(text)
         try:
             self.patcher.patch_model()
             cond = self.cond_stage_model.encode_token_weights(tokens)
@@ -434,6 +436,10 @@ class CLIP:
             raise e
         return cond

+    def encode(self, text):
+        tokens = self.tokenize(text)
+        return self.encode_from_tokens(tokens)
+
 class VAE:
     def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
         if config is None:
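Splitting `encode` into `tokenize` plus `encode_from_tokens` exposes the intermediate token/weight lists, so callers can inspect or edit them before conditioning is computed. A sketch of the equivalent call paths, assuming `clip` is a loaded `CLIP` instance (the prompt text is illustrative):

```python
# One-step path, unchanged for existing callers:
cond = clip.encode("a photo of a cat")

# Equivalent two-step path introduced by this change:
tokens = clip.tokenize("a photo of a cat")
cond = clip.encode_from_tokens(tokens)

# Requesting word ids yields (token, weight, word_id) triples,
# handy for per-word post-processing.
tokens_with_ids = clip.tokenize("a photo of a cat", return_word_ids=True)
```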


@@ -2,6 +2,8 @@ import os

 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
 import torch
+import traceback
+import zipfile

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -170,6 +172,26 @@ def unescape_important(text):
     text = text.replace("\0\2", "(")
     return text

+def safe_load_embed_zip(embed_path):
+    with zipfile.ZipFile(embed_path) as myzip:
+        names = list(filter(lambda a: "data/" in a, myzip.namelist()))
+        names.reverse()
+        for n in names:
+            with myzip.open(n) as myfile:
+                data = myfile.read()
+                number = len(data) // 4
+                length_embed = 1024 #sd2.x
+                if number < 768:
+                    continue
+                if number % 768 == 0:
+                    length_embed = 768 #sd1.x
+                num_embeds = number // length_embed
+                embed = torch.frombuffer(data, dtype=torch.float)
+                out = embed.reshape((num_embeds, length_embed)).clone()
+                del embed
+                return out
+
 def load_embed(embedding_name, embedding_directory):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
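`safe_load_embed_zip` recovers the tensor from a zipped checkpoint by arithmetic on the raw byte count: float32 is 4 bytes, SD1.x CLIP embeddings are 768 floats wide, SD2.x are 1024. A worked example of that inference on a synthetic buffer:

```python
import torch

# Synthetic SD1.x-style embedding: 2 vectors of width 768.
data = torch.randn(2, 768).numpy().tobytes()

number = len(data) // 4   # 1536 float32 values in the buffer
length_embed = 1024       # default guess: sd2.x width
if number % 768 == 0:
    length_embed = 768    # divisible by 768 -> treat as sd1.x
num_embeds = number // length_embed

# bytearray makes the buffer writable, avoiding torch.frombuffer's warning.
embed = torch.frombuffer(bytearray(data), dtype=torch.float)
print(embed.reshape((num_embeds, length_embed)).shape)  # torch.Size([2, 768])
```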
@@ -194,19 +216,33 @@ def load_embed(embedding_name, embedding_directory):

         embed_path = valid_file

-    if embed_path.lower().endswith(".safetensors"):
-        import safetensors.torch
-        embed = safetensors.torch.load_file(embed_path, device="cpu")
-    else:
-        if 'weights_only' in torch.load.__code__.co_varnames:
-            embed = torch.load(embed_path, weights_only=True, map_location="cpu")
-        else:
-            embed = torch.load(embed_path, map_location="cpu")
-    if 'string_to_param' in embed:
-        values = embed['string_to_param'].values()
-    else:
-        values = embed.values()
-    return next(iter(values))
+    embed_out = None
+
+    try:
+        if embed_path.lower().endswith(".safetensors"):
+            import safetensors.torch
+            embed = safetensors.torch.load_file(embed_path, device="cpu")
+        else:
+            if 'weights_only' in torch.load.__code__.co_varnames:
+                try:
+                    embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+                except:
+                    embed_out = safe_load_embed_zip(embed_path)
+            else:
+                embed = torch.load(embed_path, map_location="cpu")
+    except Exception as e:
+        print(traceback.format_exc())
+        print()
+        print("error loading embedding, skipping loading:", embedding_name)
+        return None
+
+    if embed_out is None:
+        if 'string_to_param' in embed:
+            values = embed['string_to_param'].values()
+        else:
+            values = embed.values()
+        embed_out = next(iter(values))
+
+    return embed_out

 class SD1Tokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
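The `'weights_only' in torch.load.__code__.co_varnames` check is a feature probe: it inspects the signature of the installed `torch.load` rather than comparing version strings, since the `weights_only` keyword only exists in newer PyTorch releases. The idiom in isolation:

```python
import torch

def torch_load_supports_weights_only():
    # Probe the signature of the installed torch.load; older
    # releases do not accept the weights_only keyword at all.
    return 'weights_only' in torch.load.__code__.co_varnames

print(torch_load_supports_weights_only())
```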
@@ -224,60 +260,97 @@ class SD1Tokenizer:
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
+        self.embedding_identifier = "embedding:"
+
+    def _try_get_embedding(self, embedding_name:str):
+        '''
+        Takes a potential embedding name and tries to retrieve it.
+        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
+        '''
+        embed = load_embed(embedding_name, self.embedding_directory)
+        if embed is None:
+            stripped = embedding_name.strip(',')
+            if len(stripped) < len(embedding_name):
+                embed = load_embed(stripped, self.embedding_directory)
+                return (embed, embedding_name[len(stripped):])
+        return (embed, "")
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False):
+        '''
+        Takes a prompt and converts it to a list of (token, weight, word id) elements.
+        Tokens can both be integer tokens and pre computed CLIP tensors.
+        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
+        Returned list has the dimensions NxM where M is the input size of CLIP
+        '''
+        if self.pad_with_end:
+            pad_token = self.end_token
+        else:
+            pad_token = 0

-    def tokenize_with_weights(self, text):
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)

+        #tokenize words
         tokens = []
-        for t in parsed_weights:
-            to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
-            while len(to_tokenize) > 0:
-                word = to_tokenize.pop(0)
-                temp_tokens = []
-                embedding_identifier = "embedding:"
-                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
-                    embedding_name = word[len(embedding_identifier):].strip('\n')
-                    embed = load_embed(embedding_name, self.embedding_directory)
-                    if embed is None:
-                        stripped = embedding_name.strip(',')
-                        if len(stripped) < len(embedding_name):
-                            embed = load_embed(stripped, self.embedding_directory)
-                            if embed is not None:
-                                to_tokenize.insert(0, embedding_name[len(stripped):])
-                    if embed is not None:
-                        if len(embed.shape) == 1:
-                            temp_tokens += [(embed, t[1])]
-                        else:
-                            for x in range(embed.shape[0]):
-                                temp_tokens += [(embed[x], t[1])]
-                    else:
-                        print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
-                elif len(word) > 0:
-                    tt = self.tokenizer(word)["input_ids"][1:-1]
-                    for x in tt:
-                        temp_tokens += [(x, t[1])]
-                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
-
-                #try not to split words in different sections
-                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
-                    for x in range(tokens_left):
-                        tokens += [(self.end_token, 1.0)]
-                tokens += temp_tokens
-
-        out_tokens = []
-        for x in range(0, len(tokens), self.max_tokens_per_section):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
-            o_token += [(self.end_token, 1.0)]
-            if self.pad_with_end:
-                o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
-            else:
-                o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
-
-            out_tokens += [o_token]
-
-        return out_tokens
+        for weighted_segment, weight in parsed_weights:
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = [x for x in to_tokenize if x != ""]
+            for word in to_tokenize:
+                #if we find an embedding, deal with the embedding
+                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
+                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
+                    embed, leftover = self._try_get_embedding(embedding_name)
+                    if embed is None:
+                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
+                    else:
+                        if len(embed.shape) == 1:
+                            tokens.append([(embed, weight)])
+                        else:
+                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
+                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
+                    if leftover != "":
+                        word = leftover
+                    else:
+                        continue
+                #parse word
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+
+        #reshape token array to CLIP input size
+        batched_tokens = []
+        batch = [(self.start_token, 1.0, 0)]
+        batched_tokens.append(batch)
+        for i, t_group in enumerate(tokens):
+            #determine if we're going to try and keep the tokens in a single batch
+            is_large = len(t_group) >= self.max_word_length
+
+            while len(t_group) > 0:
+                if len(t_group) + len(batch) > self.max_length - 1:
+                    remaining_length = self.max_length - len(batch) - 1
+                    #break word in two and add end token
+                    if is_large:
+                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        batch.append((self.end_token, 1.0, 0))
+                        t_group = t_group[remaining_length:]
+                    #add end token and pad
+                    else:
+                        batch.append((self.end_token, 1.0, 0))
+                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                    #start new batch
+                    batch = [(self.start_token, 1.0, 0)]
+                    batched_tokens.append(batch)
+                else:
+                    batch.extend([(t,w,i+1) for t,w in t_group])
+                    t_group = []
+
+        #fill last batch
+        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
+
+        if not return_word_ids:
+            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
+
+        return batched_tokens

     def untokenize(self, token_weight_pair):
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
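The rewritten tokenizer always returns fixed-size batches of `max_length` entries, each a `(token, weight)` pair, or a `(token, weight, word_id)` triple when `return_word_ids=True`; id 0 is reserved for start, end, and padding tokens. A sketch of the shape a caller sees, assuming `tokenizer` is an `SD1Tokenizer` with default settings (the token ids in the comments are illustrative):

```python
batches = tokenizer.tokenize_with_weights("(masterpiece:1.2) cat", return_word_ids=True)

# Each batch is one 77-entry CLIP window, roughly:
# [(49406, 1.0, 0),   # start token, word id 0
#  (...,   1.2, 1),   # "masterpiece" token(s), weight 1.2
#  (...,   1.0, 2),   # "cat"
#  (49407, 1.0, 0),   # end token, then padding
#  ...]
for batch in batches:
    assert len(batch) == 77
    assert all(len(entry) == 3 for entry in batch)
```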


@@ -1,4 +1,4 @@
-import sd1_clip
+from comfy import sd1_clip
 import torch
 import os


@@ -187,13 +187,12 @@ class MaskComposite:
         source_portion = source[:visible_height, :visible_width]
         destination_portion = destination[top:bottom, left:right]

-        match operation:
-            case "multiply":
-                output[top:bottom, left:right] = destination_portion * source_portion
-            case "add":
-                output[top:bottom, left:right] = destination_portion + source_portion
-            case "subtract":
-                output[top:bottom, left:right] = destination_portion - source_portion
+        if operation == "multiply":
+            output[top:bottom, left:right] = destination_portion * source_portion
+        elif operation == "add":
+            output[top:bottom, left:right] = destination_portion + source_portion
+        elif operation == "subtract":
+            output[top:bottom, left:right] = destination_portion - source_portion

         output = torch.clamp(output, 0.0, 1.0)
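Replacing the `match` statement with an `if`/`elif` chain is a compatibility fix: structural `match`/`case` requires Python 3.10, while the chain behaves identically on older interpreters. The three blend operations on toy tensors, for reference:

```python
import torch

destination_portion = torch.tensor([0.5, 0.8])
source_portion = torch.tensor([0.4, 0.6])

# Each result is clamped to [0, 1], as in MaskComposite:
print(torch.clamp(destination_portion * source_portion, 0.0, 1.0))  # multiply -> [0.20, 0.48]
print(torch.clamp(destination_portion + source_portion, 0.0, 1.0))  # add      -> [0.90, 1.00]
print(torch.clamp(destination_portion - source_portion, 0.0, 1.0))  # subtract -> [0.10, 0.20]
```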


@@ -1,6 +1,6 @@
 import os
 from comfy_extras.chainner_models import model_loading
-import model_management
+from comfy import model_management
 import torch
 import comfy.utils
 import folder_paths


@@ -10,6 +10,8 @@ import gc
 import torch
 import nodes

+import comfy.model_management
+
 def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
@@ -202,10 +204,7 @@ class PromptExecutor:
         self.server.send_sync("executing", { "node": None }, self.server.client_id)

         gc.collect()
-        if torch.cuda.is_available():
-            if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
+        comfy.model_management.soft_empty_cache()

 def validate_inputs(prompt, item):


@@ -18,6 +18,6 @@ a111:

 #other_ui:
 #    base_path: path/to/ui
 #    checkpoints: models/checkpoints
+#    custom_nodes: path/custom_nodes


@@ -12,8 +12,8 @@ except:

 folder_names_and_paths = {}

-models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+base_path = os.path.dirname(os.path.realpath(__file__))
+models_dir = os.path.join(base_path, "models")
 folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
 folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
@@ -28,6 +28,9 @@ folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")],
 folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
 folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)

+folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
+
 output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
 temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
 input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
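With `custom_nodes` registered in `folder_names_and_paths`, custom-node directories are resolved through the same lookup as model folders, which is what lets `extra_model_paths.yaml` add more of them. A sketch of the lookup, assuming a ComfyUI checkout on the import path (the appended directory is hypothetical):

```python
import os
import folder_paths

# Each entry maps a name to ([search paths], [allowed extensions]).
paths, extensions = folder_paths.folder_names_and_paths["custom_nodes"]

# Extra search locations can be appended, which is effectively what
# merging an extra_model_paths.yaml entry does:
paths.append(os.path.expanduser("~/my_comfy_nodes"))

print(folder_paths.get_folder_paths("custom_nodes"))
```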

main.py (15 changed lines)

@@ -81,6 +81,14 @@ if __name__ == "__main__":
     server = server.PromptServer(loop)
     q = execution.PromptQueue(server)

+    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
+    if os.path.isfile(extra_model_paths_config_path):
+        load_extra_path_config(extra_model_paths_config_path)
+
+    if args.extra_model_paths_config:
+        for config_path in itertools.chain(*args.extra_model_paths_config):
+            load_extra_path_config(config_path)
+
     init_custom_nodes()
     server.add_routes()
     hijack_progress(server)
@@ -91,13 +99,6 @@ if __name__ == "__main__":

     dont_print = args.dont_print_server

-    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
-    if os.path.isfile(extra_model_paths_config_path):
-        load_extra_path_config(extra_model_paths_config_path)
-
-    if args.extra_model_paths_config:
-        for config_path in itertools.chain(*args.extra_model_paths_config):
-            load_extra_path_config(config_path)
-
     if args.output_directory:
         output_dir = os.path.abspath(args.output_directory)
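The two hunks above move the `extra_model_paths.yaml` handling from after server setup to before `init_custom_nodes()`; without that reordering, a `custom_nodes` path supplied via config would be registered too late for node discovery to see it. The required ordering, schematically (a sketch, not the literal startup code):

```python
# Schematic of the startup ordering this change enforces:
load_extra_path_config("extra_model_paths.yaml")  # may register extra custom_nodes dirs
init_custom_nodes()  # scans every dir registered in folder_paths by now
```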


@@ -21,16 +21,16 @@ import comfy.utils
 import comfy.clip_vision

-import model_management
+import comfy.model_management
 import importlib

 import folder_paths

 def before_node_execution():
-    model_management.throw_exception_if_processing_interrupted()
+    comfy.model_management.throw_exception_if_processing_interrupted()

 def interrupt_processing(value=True):
-    model_management.interrupt_current_processing(value)
+    comfy.model_management.interrupt_current_processing(value)

 MAX_RESOLUTION=8192
@@ -241,7 +241,7 @@ class DiffusersLoader:
                 model_path = os.path.join(search_path, model_path)
                 break

-        return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))

 class unCLIPCheckpointLoader:
@@ -756,7 +756,7 @@ class SetLatentNoiseMask:
 def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
     latent_image = latent["samples"]
     noise_mask = None
-    device = model_management.get_torch_device()
+    device = comfy.model_management.get_torch_device()

     if disable_noise:
         noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
@@ -772,7 +772,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
         noise_mask = noise_mask.to(device)

     real_model = None
-    model_management.load_model_gpu(model)
+    comfy.model_management.load_model_gpu(model)
     real_model = model.model

     noise = noise.to(device)
@@ -802,7 +802,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     control_net_models = []
     for x in control_nets:
         control_net_models += x.get_control_models()
-    model_management.load_controlnet_gpu(control_net_models)
+    comfy.model_management.load_controlnet_gpu(control_net_models)

     if sampler_name in comfy.samplers.KSampler.SAMPLERS:
         sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
@@ -1255,15 +1255,16 @@ def load_custom_node(module_path):
         print(f"Cannot import {module_path} module for custom nodes:", e)

 def load_custom_nodes():
-    CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
-    possible_modules = os.listdir(CUSTOM_NODE_PATH)
-    if "__pycache__" in possible_modules:
-        possible_modules.remove("__pycache__")
+    node_paths = folder_paths.get_folder_paths("custom_nodes")
+    for custom_node_path in node_paths:
+        possible_modules = os.listdir(custom_node_path)
+        if "__pycache__" in possible_modules:
+            possible_modules.remove("__pycache__")

-    for possible_module in possible_modules:
-        module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
-        if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
-        load_custom_node(module_path)
+        for possible_module in possible_modules:
+            module_path = os.path.join(custom_node_path, possible_module)
+            if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
+            load_custom_node(module_path)

 def init_custom_nodes():
     load_custom_nodes()
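`load_custom_nodes` now iterates every directory registered under `custom_nodes` instead of a single hard-coded path. For context, ComfyUI custom-node modules conventionally expose a `NODE_CLASS_MAPPINGS` dict (not shown in this diff); a minimal hypothetical node file might look like this sketch:

```python
# custom_nodes/example_brightness.py -- hypothetical minimal custom node.

class ExampleBrightness:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "image": ("IMAGE",),
            "factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
        }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply"
    CATEGORY = "image/adjust"

    def apply(self, image, factor):
        # IMAGE tensors are batched float images in [0, 1].
        return ((image * factor).clamp(0.0, 1.0),)

NODE_CLASS_MAPPINGS = {"ExampleBrightness": ExampleBrightness}
```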


@@ -122,7 +122,7 @@
 "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n",
 "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n",
 "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n",
-"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n",
+"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n",
 "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n",
 "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n",
 "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n",


@@ -0,0 +1,76 @@
import { app } from "/scripts/app.js";

const id = "Comfy.Keybinds";
app.registerExtension({
    name: id,
    init() {
        const keybindListener = function(event) {
            const target = event.composedPath()[0];

            if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") {
                return;
            }

            const modifierPressed = event.ctrlKey || event.metaKey;

            // Queue prompt using ctrl or command + enter
            if (modifierPressed && (event.key === "Enter" || event.keyCode === 13 || event.keyCode === 10)) {
                app.queuePrompt(event.shiftKey ? -1 : 0);
                return;
            }

            const modifierKeyIdMap = {
                "s": "#comfy-save-button",
                83: "#comfy-save-button",
                "o": "#comfy-file-input",
                79: "#comfy-file-input",
                "Backspace": "#comfy-clear-button",
                8: "#comfy-clear-button",
                "Delete": "#comfy-clear-button",
                46: "#comfy-clear-button",
                "d": "#comfy-load-default-button",
                68: "#comfy-load-default-button",
            };

            const modifierKeybindId = modifierKeyIdMap[event.key] || modifierKeyIdMap[event.keyCode];

            if (modifierPressed && modifierKeybindId) {
                event.preventDefault();

                const elem = document.querySelector(modifierKeybindId);
                elem.click();
                return;
            }

            // Finished Handling all modifier keybinds, now handle the rest
            if (event.ctrlKey || event.altKey || event.metaKey) {
                return;
            }

            // Close out of modals using escape
            if (event.key === "Escape" || event.keyCode === 27) {
                const modals = document.querySelectorAll(".comfy-modal");
                const modal = Array.from(modals).find(modal => window.getComputedStyle(modal).getPropertyValue("display") !== "none");

                if (modal) {
                    modal.style.display = "none";
                }
            }

            const keyIdMap = {
                "q": "#comfy-view-queue-button",
                81: "#comfy-view-queue-button",
                "h": "#comfy-view-history-button",
                72: "#comfy-view-history-button",
                "r": "#comfy-refresh-button",
                82: "#comfy-refresh-button",
            };

            const buttonId = keyIdMap[event.key] || keyIdMap[event.keyCode];

            if (buttonId) {
                const button = document.querySelector(buttonId);
                button.click();
            }
        }

        window.addEventListener("keydown", keybindListener, true);
    }
});


@@ -0,0 +1,41 @@
import {app} from "../../scripts/app.js";
import {ComfyWidgets} from "../../scripts/widgets.js";

// Node that add notes to your project
app.registerExtension({
    name: "Comfy.NoteNode",
    registerCustomNodes() {
        class NoteNode {
            color=LGraphCanvas.node_colors.yellow.color;
            bgcolor=LGraphCanvas.node_colors.yellow.bgcolor;
            groupcolor = LGraphCanvas.node_colors.yellow.groupcolor;
            constructor() {
                if (!this.properties) {
                    this.properties = {};
                    this.properties.text="";
                }

                ComfyWidgets.STRING(this, "", ["", {default:this.properties.text, multiline: true}], app)

                this.serialize_widgets = true;
                this.isVirtualNode = true;
            }
        }

        // Load default visibility
        LiteGraph.registerNodeType(
            "Note",
            Object.assign(NoteNode, {
                title_mode: LiteGraph.NORMAL_TITLE,
                title: "Note",
                collapsable: true,
            })
        );

        NoteNode.category = "utils";
    },
});


@@ -159,9 +159,11 @@ app.registerExtension({
             const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined;

             const input = this.inputs[slot];
-            if (!input.widget || !input[ignoreDblClick])// Not a widget input or already handled input
-            {
-                if (!(input.type in ComfyWidgets)) return r;//also Not a ComfyWidgets input (do nothing)
+            if (!input.widget || !input[ignoreDblClick]) {
+                // Not a widget input or already handled input
+                if (!(input.type in ComfyWidgets) && !(input.widget.config?.[0] instanceof Array)) {
+                    return r; //also Not a ComfyWidgets input or combo (do nothing)
+                }
             }

             // Create a primitive node
@@ -333,7 +335,20 @@ app.registerExtension({
                 const config1 = this.outputs[0].widget.config;
                 const config2 = input.widget.config;

-                if (config1[0] !== config2[0]) return false;
+                if (config1[0] instanceof Array) {
+                    // These checks shouldnt actually be necessary as the types should match
+                    // but double checking doesn't hurt
+
+                    // New input isnt a combo
+                    if (!(config2[0] instanceof Array)) return false;
+                    // New input combo has a different size
+                    if (config1[0].length !== config2[0].length) return false;
+                    // New input combo has different elements
+                    if (config1[0].find((v, i) => config2[0][i] !== v)) return false;
+                } else if (config1[0] !== config2[0]) {
+                    // Configs dont match
+                    return false;
+                }

                 for (const k in config1[1]) {
                     if (k !== "default") {


@@ -4,27 +4,48 @@ import { api } from "./api.js";
 import { defaultGraph } from "./defaultGraph.js";
 import { getPngMetadata, importA1111 } from "./pnginfo.js";

-class ComfyApp {
-    /**
-     * List of {number, batchCount} entries to queue
+/**
+ * @typedef {import("types/comfy").ComfyExtension} ComfyExtension
+ */
+
+export class ComfyApp {
+    /**
+     * List of entries to queue
+     * @type {{number: number, batchCount: number}[]}
      */
     #queueItems = [];

     /**
      * If the queue is currently being processed
+     * @type {boolean}
      */
     #processingQueue = false;

     constructor() {
         this.ui = new ComfyUI(this);
+
+        /**
+         * List of extensions that are registered with the app
+         * @type {ComfyExtension[]}
+         */
         this.extensions = [];
+
+        /**
+         * Stores the execution output data for each node
+         * @type {Record<string, any>}
+         */
         this.nodeOutputs = {};
+
+        /**
+         * If the shift key on the keyboard is pressed
+         * @type {boolean}
+         */
         this.shiftDown = false;
     }

     /**
      * Invoke an extension callback
-     * @param {string} method The extension callback to execute
-     * @param {...any} args Any arguments to pass to the callback
+     * @param {keyof ComfyExtension} method The extension callback to execute
+     * @param {any[]} args Any arguments to pass to the callback
      * @returns
      */
     #invokeExtensions(method, ...args) {
@@ -691,11 +712,6 @@ class ComfyApp {
     #addKeyboardHandler() {
         window.addEventListener("keydown", (e) => {
             this.shiftDown = e.shiftKey;
-
-            // Queue prompt using ctrl or command + enter
-            if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
-                this.queuePrompt(e.shiftKey ? -1 : 0);
-            }
         });
         window.addEventListener("keyup", (e) => {
             this.shiftDown = e.shiftKey;
@@ -1120,6 +1136,10 @@ class ComfyApp {
         }
     }

+    /**
+     * Registers a Comfy web extension with the app
+     * @param {ComfyExtension} extension
+     */
     registerExtension(extension) {
         if (!extension.name) {
             throw new Error("Extensions must have a 'name' property.");


@@ -131,6 +131,7 @@ export async function importA1111(graph, parameters) {
         }

         function replaceEmbeddings(text) {
+            if(!embeddings.length) return text;
             return text.replaceAll(
                 new RegExp(
                     "\\b(" + embeddings.map((e) => e.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")).join("\\b|\\b") + ")\\b",


@@ -431,7 +431,15 @@ export class ComfyUI {
             defaultValue: true,
         });

+        const promptFilename = this.settings.addSetting({
+            id: "Comfy.PromptFilename",
+            name: "Prompt for filename when saving workflow",
+            type: "boolean",
+            defaultValue: true,
+        });
+
         const fileInput = $el("input", {
+            id: "comfy-file-input",
             type: "file",
             accept: ".json,image/png",
             style: { display: "none" },
@@ -448,6 +456,7 @@ export class ComfyUI {
                 $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
             ]),
             $el("button.comfy-queue-btn", {
+                id: "queue-button",
                 textContent: "Queue Prompt",
                 onclick: () => app.queuePrompt(0, this.batchCount),
             }),
@@ -496,9 +505,10 @@ export class ComfyUI {
                 ]),
             ]),
             $el("div.comfy-menu-btns", [
-                $el("button", { textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
+                $el("button", { id: "queue-front-button", textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
                 $el("button", {
                     $: (b) => (this.queue.button = b),
+                    id: "comfy-view-queue-button",
                     textContent: "View Queue",
                     onclick: () => {
                         this.history.hide();
@@ -507,6 +517,7 @@ export class ComfyUI {
                 }),
                 $el("button", {
                     $: (b) => (this.history.button = b),
+                    id: "comfy-view-history-button",
                     textContent: "View History",
                     onclick: () => {
                         this.queue.hide();
@@ -517,14 +528,23 @@ export class ComfyUI {
             this.queue.element,
             this.history.element,
             $el("button", {
+                id: "comfy-save-button",
                 textContent: "Save",
                 onclick: () => {
+                    let filename = "workflow.json";
+                    if (promptFilename.value) {
+                        filename = prompt("Save workflow as:", filename);
+                        if (!filename) return;
+                        if (!filename.toLowerCase().endsWith(".json")) {
+                            filename += ".json";
+                        }
+                    }
                     const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
                     const blob = new Blob([json], { type: "application/json" });
                     const url = URL.createObjectURL(blob);
                     const a = $el("a", {
                         href: url,
-                        download: "workflow.json",
+                        download: filename,
                         style: { display: "none" },
                         parent: document.body,
                     });
@@ -535,15 +555,15 @@ export class ComfyUI {
                     }, 0);
                 },
             }),
-            $el("button", { textContent: "Load", onclick: () => fileInput.click() }),
-            $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
-            $el("button", { textContent: "Clear", onclick: () => {
+            $el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
+            $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
+            $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
                 if (!confirmClear.value || confirm("Clear workflow?")) {
                     app.clean();
                     app.graph.clear();
                 }
             }}),
-            $el("button", { textContent: "Load Default", onclick: () => {
+            $el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
                 if (!confirmClear.value || confirm("Load default workflow?")) {
                     app.loadGraphData()
                 }

web/types/comfy.d.ts (new file, 78 lines, vendored)

@@ -0,0 +1,78 @@
import { LGraphNode, IWidget } from "./litegraph";
import { ComfyApp } from "/scripts/app";

export interface ComfyExtension {
    /**
     * The name of the extension
     */
    name: string;
    /**
     * Allows any initialisation, e.g. loading resources. Called after the canvas is created but before nodes are added
     * @param app The ComfyUI app instance
     */
    init(app: ComfyApp): Promise<void>;
    /**
     * Allows any additional setup, called after the application is fully set up and running
     * @param app The ComfyUI app instance
     */
    setup(app: ComfyApp): Promise<void>;
    /**
     * Called before nodes are registered with the graph
     * @param defs The collection of node definitions, add custom ones or edit existing ones
     * @param app The ComfyUI app instance
     */
    addCustomNodeDefs(defs: Record<string, ComfyObjectInfo>, app: ComfyApp): Promise<void>;
    /**
     * Allows the extension to add custom widgets
     * @param app The ComfyUI app instance
     * @returns An array of {[widget name]: widget data}
     */
    getCustomWidgets(
        app: ComfyApp
    ): Promise<
        Array<
            Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
        >
    >;
    /**
     * Allows the extension to add additional handling to the node before it is registered with LGraph
     * @param nodeType The node class (not an instance)
     * @param nodeData The original node object info config object
     * @param app The ComfyUI app instance
     */
    beforeRegisterNodeDef(nodeType: typeof LGraphNode, nodeData: ComfyObjectInfo, app: ComfyApp): Promise<void>;
    /**
     * Allows the extension to register additional nodes with LGraph after standard nodes are added
     * @param app The ComfyUI app instance
     */
    registerCustomNodes(app: ComfyApp): Promise<void>;
    /**
     * Allows the extension to modify a node that has been reloaded onto the graph.
     * If you break something in the backend and want to patch workflows in the frontend
     * This is the place to do this
     * @param node The node that has been loaded
     * @param app The ComfyUI app instance
     */
    loadedGraphNode(node: LGraphNode, app: ComfyApp);
    /**
     * Allows the extension to run code after the constructor of the node
     * @param node The node that has been created
     * @param app The ComfyUI app instance
     */
    nodeCreated(node: LGraphNode, app: ComfyApp);
}

export type ComfyObjectInfo = {
    name: string;
    display_name?: string;
    description?: string;
    category: string;
    input?: {
        required?: Record<string, ComfyObjectInfoConfig>;
        optional?: Record<string, ComfyObjectInfoConfig>;
    };
    output?: string[];
    output_name: string[];
};

export type ComfyObjectInfoConfig = [string | any[]] | [string | any[], any];

web/types/litegraph.d.ts (new file, 1506 lines, vendored)

File diff suppressed because it is too large.