Merge branch 'comfyanonymous:master' into bugfix/extra_data

This commit is contained in:
Dr.Lt.Data 2023-08-31 17:01:09 +09:00 committed by GitHub
commit 7f229b7499
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 665 additions and 421 deletions

View File

@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

View File

@ -2,14 +2,27 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
from .utils import load_torch_file, transformers_convert
import os
import torch
import contextlib
import comfy.ops
import comfy.model_patcher
import comfy.model_management
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
with comfy.ops.use_comfy_ops():
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = torch.float32
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16
with comfy.ops.use_comfy_ops(offload_device, self.dtype):
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
@ -27,7 +40,21 @@ class ClipVisionModel():
img = torch.clip((255. * image), 0, 255).round().int()
img = list(map(lambda a: a, img))
inputs = self.processor(images=img, return_tensors="pt")
outputs = self.model(**inputs)
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = inputs['pixel_values'].to(self.load_device)
if self.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
outputs = self.model(pixel_values=pixel_values)
for k in outputs:
t = outputs[k]
if t is not None:
outputs[k] = t.cpu()
return outputs
def convert_to_transformers(sd, prefix):

View File

@ -1,9 +1,10 @@
import torch
import math
import os
import comfy.utils
import comfy.sd
import comfy.model_management
import comfy.model_detection
import comfy.model_patcher
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@ -129,7 +130,7 @@ class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number):
@ -257,12 +258,7 @@ class ControlLora(ControlNet):
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
try:
comfy.utils.set_attr(self.control_model, k, weight)
except:
@ -391,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
control_model = control_model.half()
global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
@ -468,7 +465,7 @@ def load_t2i_adapter(t2i_data):
if len(down_opts) > 0:
use_conv = True
xl = False
if cin == 256:
if cin == 256 or cin == 768:
xl = True
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
else:

View File

@ -1,87 +1,36 @@
import json
import os
import yaml
import folder_paths
from comfy.sd import load_checkpoint
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
from . import diffusers_convert
import comfy.sd
def first_file(path, filenames):
for f in filenames:
p = os.path.join(path, f)
if os.path.exists(p):
return p
return None
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
text_encoder_paths = [text_encoder1_path]
if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path)
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
unet = comfy.sd.load_unet(unet_path)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
clip = None
if output_clip:
clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
vae = None
if output_vae:
vae = comfy.sd.VAE(ckpt_path=vae_path)
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
return (unet, clip, vae)

View File

@ -56,7 +56,18 @@ class Upsample(nn.Module):
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
try:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
except: #operation not implemented for bf16
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
if self.with_conv:
x = self.conv(x)
return x
@ -74,11 +85,10 @@ class Downsample(nn.Module):
stride=2,
padding=0)
def forward(self, x, already_padded=False):
def forward(self, x):
if self.with_conv:
if not already_padded:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@ -275,25 +285,17 @@ class MemoryEfficientAttnBlock(nn.Module):
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
@ -603,9 +605,6 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
already_padded = True
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
@ -614,8 +613,7 @@ class Encoder(nn.Module):
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions-1:
h = self.down[i_level].downsample(h, already_padded)
already_padded = False
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h, temb)

View File

@ -118,6 +118,19 @@ def load_lora(lora, to_load):
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
w_norm = lora.get(w_norm_name, None)
b_norm = lora.get(b_norm_name, None)
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,)
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)

View File

@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import numpy as np
from enum import Enum
from . import utils
@ -18,8 +19,9 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config, device=device)
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
@ -93,7 +95,11 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_state_dict = self.diffusion_model.state_dict()
unet_sd = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:

View File

@ -1,6 +1,7 @@
import psutil
from enum import Enum
from comfy.cli_args import args
import comfy.utils
import torch
import sys
@ -147,15 +148,27 @@ def is_nvidia():
return True
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
VAE_DTYPE = torch.float32
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
if torch.cuda.is_bf16_supported():
VAE_DTYPE = torch.bfloat16
except:
pass
if args.fp16_vae:
VAE_DTYPE = torch.float16
elif args.bf16_vae:
VAE_DTYPE = torch.bfloat16
elif args.fp32_vae:
VAE_DTYPE = torch.float32
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
@ -227,6 +240,7 @@ try:
except:
print("Could not pick default device.")
print("VAE dtype:", VAE_DTYPE)
current_loaded_models = []
@ -447,12 +461,8 @@ def vae_offload_device():
return torch.device("cpu")
def vae_dtype():
if args.fp16_vae:
return torch.float16
elif args.bf16_vae:
return torch.bfloat16
else:
return torch.float32
global VAE_DTYPE
return VAE_DTYPE
def get_autocast_device(dev):
if hasattr(dev, 'type'):
@ -637,6 +647,13 @@ def soft_empty_cache():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key):
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
return weight
#TODO: might be cleaner to put this somewhere else
import threading

270
comfy/model_patcher.py Normal file
View File

@ -0,0 +1,270 @@
import torch
import copy
import inspect
import comfy.utils
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to

View File

@ -1,7 +1,5 @@
import torch
import contextlib
import copy
import inspect
import math
from comfy import model_management
@ -21,8 +19,10 @@ from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@ -53,271 +53,6 @@ def load_clip_weights(model, sd):
sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return load_model_weights(model, sd)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", k)
continue
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = comfy.lora.model_lora_keys_unet(model.model)
@ -346,7 +81,7 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = load_device
params['device'] = offload_device
if model_management.should_use_fp16(load_device, prioritize_performance=False):
params['dtype'] = torch.float16
else:
@ -355,7 +90,7 @@ class CLIP:
self.cond_stage_model = clip(**(params))
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None
def clone(self):
@ -573,7 +308,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually
@ -614,10 +349,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
class EmptyClass:
pass
model_config = EmptyClass()
model_config.unet_config = unet_config
model_config = comfy.supported_models_base.BASE({})
from . import latent_formats
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
model_config.unet_config = unet_config
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(model_config, model_type=model_type)
@ -653,7 +389,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
w.cond_stage_model = clip.cond_stage_model
load_clip_weights(w, state_dict)
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
sd = comfy.utils.load_torch_file(ckpt_path)
@ -705,7 +441,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0:
print("left over keys:", left_over)
model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@ -735,7 +471,7 @@ def load_unet(unet_path): #load unet in diffusers format
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()])

View File

@ -1,6 +1,7 @@
import torch
from . import model_base
from . import utils
from . import latent_formats
def state_dict_key_replace(state_dict, keys_to_replace):
@ -33,6 +34,8 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
beta_schedule = "linear"
latent_format = latent_formats.LatentFormat
@classmethod
def matches(s, unet_config):

View File

@ -3,7 +3,7 @@ import math
import torch
import torch.nn.functional as F
import comfy.model_management
def get_canny_nms_kernel(device=None, dtype=None):
"""Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
@ -290,8 +290,8 @@ class Canny:
CATEGORY = "image/preprocessors"
def detect_edge(self, image, low_threshold, high_threshold):
output = canny(image.movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].repeat(1, 3, 1, 1).movedim(1, -1)
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
return (img_out,)
NODE_CLASS_MAPPINGS = {

View File

@ -125,6 +125,27 @@ class ImageToMask:
mask = image[0, :, :, channels.index(channel)]
return (mask,)
class ImageColorToMask:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, image, color):
temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2]
mask = torch.where(temp == color, 255, 0).float()
return (mask,)
class SolidMask:
@classmethod
def INPUT_TYPES(cls):
@ -315,6 +336,7 @@ NODE_CLASS_MAPPINGS = {
"ImageCompositeMasked": ImageCompositeMasked,
"MaskToImage": MaskToImage,
"ImageToMask": ImageToMask,
"ImageColorToMask": ImageColorToMask,
"SolidMask": SolidMask,
"InvertMask": InvertMask,
"CropMask": CropMask,

View File

@ -244,14 +244,16 @@ class VAEDecode:
class VAEDecodeTiled:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64})
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
def decode(self, vae, samples):
return (vae.decode_tiled(samples["samples"]), )
def decode(self, vae, samples, tile_size):
return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
class VAEEncode:
@classmethod
@ -280,15 +282,17 @@ class VAEEncode:
class VAEEncodeTiled:
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing"
def encode(self, vae, pixels):
def encode(self, vae, pixels, tile_size):
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3])
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
return ({"samples":t}, )
class VAEEncodeForInpaint:
@ -471,7 +475,7 @@ class DiffusersLoader:
model_path = path
break
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader:

View File

@ -1,6 +1,8 @@
import os
import sys
import asyncio
import traceback
import nodes
import folder_paths
import execution
@ -10,6 +12,7 @@ import json
import glob
import struct
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
try:
@ -79,7 +82,7 @@ class PromptServer():
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
self.app = web.Application(client_max_size=104857600, middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
@ -88,6 +91,8 @@ class PromptServer():
self.last_node_id = None
self.client_id = None
self.on_prompt_handlers = []
@routes.get('/ws')
async def websocket_handler(request):
ws = web.WebSocketResponse()
@ -122,7 +127,7 @@ class PromptServer():
@routes.get("/embeddings")
def get_embeddings(self):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings)))
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/extensions")
async def get_extensions(request):
@ -229,13 +234,17 @@ class PromptServer():
if os.path.isfile(file):
with Image.open(file) as original_pil:
metadata = PngInfo()
if hasattr(original_pil,'text'):
for key in original_pil.text:
metadata.add_text(key, original_pil.text[key])
original_pil = original_pil.convert('RGBA')
mask_pil = Image.open(image.file).convert('RGBA')
# alpha copy
new_alpha = mask_pil.getchannel('A')
original_pil.putalpha(new_alpha)
original_pil.save(filepath, compress_level=4)
original_pil.save(filepath, compress_level=4, pnginfo=metadata)
return image_upload(post, image_save_function)
@ -438,6 +447,7 @@ class PromptServer():
resp_code = 200
out_string = ""
json_data = await request.json()
json_data = self.trigger_on_prompt(json_data)
if "number" in json_data:
number = float(json_data['number'])
@ -606,3 +616,15 @@ class PromptServer():
if call_on_start is not None:
call_on_start(address, port)
def add_on_prompt_handler(self, handler):
self.on_prompt_handlers.append(handler)
def trigger_on_prompt(self, json_data):
for handler in self.on_prompt_handlers:
try:
json_data = handler(json_data)
except Exception as e:
print(f"[ERROR] An error occurred during the on_prompt_handler processing")
traceback.print_exc()
return json_data

View File

@ -0,0 +1,167 @@
import {app} from "../../scripts/app.js";
function setNodeMode(node, mode) {
node.mode = mode;
node.graph.change();
}
app.registerExtension({
name: "Comfy.GroupOptions",
setup() {
const orig = LGraphCanvas.prototype.getCanvasMenuOptions;
// graph_mouse
LGraphCanvas.prototype.getCanvasMenuOptions = function () {
const options = orig.apply(this, arguments);
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
if (!group) {
return options;
}
// Group nodes aren't recomputed until the group is moved, this ensures the nodes are up-to-date
group.recomputeInsideNodes();
const nodesInGroup = group._nodes;
// No nodes in group, return default options
if (nodesInGroup.length === 0) {
return options;
} else {
// Add a separator between the default options and the group options
options.push(null);
}
// Check if all nodes are the same mode
let allNodesAreSameMode = true;
for (let i = 1; i < nodesInGroup.length; i++) {
if (nodesInGroup[i].mode !== nodesInGroup[0].mode) {
allNodesAreSameMode = false;
break;
}
}
// Modes
// 0: Always
// 1: On Event
// 2: Never
// 3: On Trigger
// 4: Bypass
// If all nodes are the same mode, add a menu option to change the mode
if (allNodesAreSameMode) {
const mode = nodesInGroup[0].mode;
switch (mode) {
case 0:
// All nodes are always, option to disable, and bypass
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
case 2:
// All nodes are never, option to enable, and bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
case 4:
// All nodes are bypass, option to enable, and disable
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
break;
default:
// All nodes are On Trigger or On Event(Or other?), option to disable, set to always, or bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
}
} else {
// Nodes are not all the same mode, add a menu option to change the mode to always, never, or bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
}
return options
}
}
});

View File

@ -6233,11 +6233,17 @@ LGraphNode.prototype.executeAction = function(action)
,posAdd:[!mClikSlot_isOut?-30:30, -alphaPosY*130] //-alphaPosY*30]
,posSizeFix:[!mClikSlot_isOut?-1:0, 0] //-alphaPosY*2*/
});
skip_action = true;
}
}
}
}
if (!skip_action && this.allow_dragcanvas) {
//console.log("pointerevents: dragging_canvas start from middle button");
this.dragging_canvas = true;
}
} else if (e.which == 3 || this.pointer_is_double) {

View File

@ -299,11 +299,17 @@ export const ComfyWidgets = {
const defaultVal = inputData[1].default || "";
const multiline = !!inputData[1].multiline;
let res;
if (multiline) {
return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
res = addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
} else {
return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
res = { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
}
if(inputData[1].dynamicPrompts != undefined)
res.widget.dynamicPrompts = inputData[1].dynamicPrompts;
return res;
},
COMBO(node, inputName, inputData) {
const type = inputData[0];