diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b4f22f319..fda245433 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") -fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.") +fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") +fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index a887e51b5..daaa2f2bf 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -2,14 +2,27 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm from .utils import load_torch_file, transformers_convert import os import torch +import contextlib + import comfy.ops +import comfy.model_patcher +import comfy.model_management class ClipVisionModel(): def __init__(self, json_config): config = CLIPVisionConfig.from_json_file(json_config) - with comfy.ops.use_comfy_ops(): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = torch.float32 + if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): + self.dtype = torch.float16 + + with comfy.ops.use_comfy_ops(offload_device, self.dtype): with modeling_utils.no_init_weights(): self.model = CLIPVisionModelWithProjection(config) + self.model.to(self.dtype) + + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.processor = CLIPImageProcessor(crop_size=224, do_center_crop=True, do_convert_rgb=True, @@ -27,7 +40,21 @@ class ClipVisionModel(): img = torch.clip((255. 
* image), 0, 255).round().int() img = list(map(lambda a: a, img)) inputs = self.processor(images=img, return_tensors="pt") - outputs = self.model(**inputs) + comfy.model_management.load_model_gpu(self.patcher) + pixel_values = inputs['pixel_values'].to(self.load_device) + + if self.dtype != torch.float32: + precision_scope = torch.autocast + else: + precision_scope = lambda a, b: contextlib.nullcontext(a) + + with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32): + outputs = self.model(pixel_values=pixel_values) + + for k in outputs: + t = outputs[k] + if t is not None: + outputs[k] = t.cpu() return outputs def convert_to_transformers(sd, prefix): diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 5279307c0..83e1be058 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,9 +1,10 @@ import torch import math +import os import comfy.utils -import comfy.sd import comfy.model_management import comfy.model_detection +import comfy.model_patcher import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -129,7 +130,7 @@ class ControlNet(ControlBase): def __init__(self, control_model, global_average_pooling=False, device=None): super().__init__(device) self.control_model = control_model - self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond, batched_number): @@ -257,12 +258,7 @@ class ControlLora(ControlNet): cm = self.control_model.state_dict() for k in sd: - weight = sd[k] - if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. - key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. 
- op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1])) - weight = op._hf_hook.weights_map[key_split[-1]] - + weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) try: comfy.utils.set_attr(self.control_model, k, weight) except: @@ -391,7 +387,8 @@ def load_controlnet(ckpt_path, model=None): control_model = control_model.half() global_average_pooling = False - if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + filename = os.path.splitext(ckpt_path)[0] + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True control = ControlNet(control_model, global_average_pooling=global_average_pooling) @@ -468,7 +465,7 @@ def load_t2i_adapter(t2i_data): if len(down_opts) > 0: use_conv = True xl = False - if cin == 256: + if cin == 256 or cin == 768: xl = True model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) else: diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index 11d94c340..a52e0102b 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -1,87 +1,36 @@ import json import os -import yaml -import folder_paths -from comfy.sd import load_checkpoint -import os.path as osp -import re -import torch -from safetensors.torch import load_file, save_file -from . import diffusers_convert +import comfy.sd +def first_file(path, filenames): + for f in filenames: + p = os.path.join(path, f) + if os.path.exists(p): + return p + return None -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) +def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None): + diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"] + unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names) + vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names) - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"] + text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names) + text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names) - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + text_encoder_paths = [text_encoder1_path] + if text_encoder2_path is not None: + text_encoder_paths.append(text_encoder2_path) - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) + unet = comfy.sd.load_unet(unet_path) - model_config_params = 
config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + clip = None + if output_clip: + clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory) - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + vae = None + if output_vae: + vae = comfy.sd.VAE(ckpt_path=vae_path) - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config) + return (unet, clip, vae) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index b596408d3..431548483 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -56,7 +56,18 @@ class Upsample(nn.Module): padding=1) def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + try: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + except: #operation not implemented for bf16 + b, c, h, w = x.shape + out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device) + split = 8 + l = out.shape[1] // split + for i in range(0, out.shape[1], l): + out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype) + del x + x = out + if self.with_conv: x = self.conv(x) return x @@ -74,11 +85,10 @@ class Downsample(nn.Module): stride=2, padding=0) - def forward(self, x, already_padded=False): + def forward(self, x): if self.with_conv: - if not already_padded: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -275,25 +285,17 @@ class MemoryEfficientAttnBlock(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), (q, k, v), ) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + try: + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + out = out.transpose(1, 2).reshape(B, C, H, W) + except NotImplementedError as e: + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + out = self.proj_out(out) return x+out @@ -603,9 +605,6 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - already_padded = True # downsampling h = self.conv_in(x) for i_level in range(self.num_resolutions): @@ -614,8 +613,7 @@ class Encoder(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) if i_level != self.num_resolutions-1: - h = self.down[i_level].downsample(h, already_padded) - already_padded = False + h = self.down[i_level].downsample(h) # middle h = self.mid.block_1(h, temb) diff --git a/comfy/lora.py b/comfy/lora.py index d685a455e..3009a1c9e 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -118,6 +118,19 @@ def load_lora(lora, to_load): if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): patch_dict[to_load[x]] = (lokr_w1, 
lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + + w_norm_name = "{}.w_norm".format(x) + b_norm_name = "{}.b_norm".format(x) + w_norm = lora.get(w_norm_name, None) + b_norm = lora.get(b_norm_name, None) + + if w_norm is not None: + loaded_keys.add(w_norm_name) + patch_dict[to_load[x]] = (w_norm,) + if b_norm is not None: + loaded_keys.add(b_norm_name) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) + for x in lora.keys(): if x not in loaded_keys: print("lora key not loaded", x) diff --git a/comfy/model_base.py b/comfy/model_base.py index 979e2c65e..acd4169a8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep +import comfy.model_management import numpy as np from enum import Enum from . import utils @@ -18,8 +19,9 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config - self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) - self.diffusion_model = UNetModel(**unet_config, device=device) + self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + if not unet_config.get("disable_unet_model_creation", False): + self.diffusion_model = UNetModel(**unet_config, device=device) self.model_type = model_type self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: @@ -93,7 +95,11 @@ class BaseModel(torch.nn.Module): def state_dict_for_saving(self, clip_state_dict, vae_state_dict): clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) - unet_state_dict = self.diffusion_model.state_dict() + unet_sd = self.diffusion_model.state_dict() + unet_state_dict = {} + for k in unet_sd: + unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) + unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) if self.get_dtype() == torch.float16: diff --git a/comfy/model_management.py b/comfy/model_management.py index 016434492..aca8af999 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,7 @@ import psutil from enum import Enum from comfy.cli_args import args +import comfy.utils import torch import sys @@ -147,15 +148,27 @@ def is_nvidia(): return True ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention +VAE_DTYPE = torch.float32 -if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - try: - if is_nvidia(): - torch_version = torch.version.__version__ - if int(torch_version[0]) >= 2: + +try: + if is_nvidia(): + torch_version = torch.version.__version__ + if int(torch_version[0]) >= 2: + if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - except: - pass + if 
torch.cuda.is_bf16_supported(): + VAE_DTYPE = torch.bfloat16 +except: + pass + +if args.fp16_vae: + VAE_DTYPE = torch.float16 +elif args.bf16_vae: + VAE_DTYPE = torch.bfloat16 +elif args.fp32_vae: + VAE_DTYPE = torch.float32 + if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) @@ -227,6 +240,7 @@ try: except: print("Could not pick default device.") +print("VAE dtype:", VAE_DTYPE) current_loaded_models = [] @@ -447,12 +461,8 @@ def vae_offload_device(): return torch.device("cpu") def vae_dtype(): - if args.fp16_vae: - return torch.float16 - elif args.bf16_vae: - return torch.bfloat16 - else: - return torch.float32 + global VAE_DTYPE + return VAE_DTYPE def get_autocast_device(dev): if hasattr(dev, 'type'): @@ -637,6 +647,13 @@ def soft_empty_cache(): torch.cuda.empty_cache() torch.cuda.ipc_collect() +def resolve_lowvram_weight(weight, model, key): + if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. + key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. + op = comfy.utils.get_attr(model, '.'.join(key_split[:-1])) + weight = op._hf_hook.weights_map[key_split[-1]] + return weight + #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py new file mode 100644 index 000000000..a6ee0bae1 --- /dev/null +++ b/comfy/model_patcher.py @@ -0,0 +1,270 @@ +import torch +import copy +import inspect + +import comfy.utils + +class ModelPatcher: + def __init__(self, model, load_device, offload_device, size=0, current_device=None): + self.size = size + self.model = model + self.patches = {} + self.backup = {} + self.model_options = {"transformer_options":{}} + self.model_size() + self.load_device = load_device + self.offload_device = offload_device + if current_device is None: + self.current_device = self.offload_device + else: + self.current_device = current_device + + def model_size(self): + if self.size > 0: + return self.size + model_sd = self.model.state_dict() + size = 0 + for k in model_sd: + t = model_sd[k] + size += t.nelement() * t.element_size() + self.size = size + self.model_keys = set(model_sd.keys()) + return size + + def clone(self): + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) + n.patches = {} + for k in self.patches: + n.patches[k] = self.patches[k][:] + + n.model_options = copy.deepcopy(self.model_options) + n.model_keys = self.model_keys + return n + + def is_clone(self, other): + if hasattr(other, 'model') and self.model is other.model: + return True + return False + + def set_model_sampler_cfg_function(self, sampler_cfg_function): + if len(inspect.signature(sampler_cfg_function).parameters) == 3: + self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way + else: + self.model_options["sampler_cfg_function"] = sampler_cfg_function + + def set_model_unet_function_wrapper(self, unet_wrapper_function): + self.model_options["model_function_wrapper"] = unet_wrapper_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_patch_replace(self, patch, name, block_name, number): + to = self.model_options["transformer_options"] + if "patches_replace" not 
in to: + to["patches_replace"] = {} + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + to["patches_replace"][name][(block_name, number)] = patch + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def set_model_attn1_replace(self, patch, block_name, number): + self.set_model_patch_replace(patch, "attn1", block_name, number) + + def set_model_attn2_replace(self, patch, block_name, number): + self.set_model_patch_replace(patch, "attn2", block_name, number) + + def set_model_attn1_output_patch(self, patch): + self.set_model_patch(patch, "attn1_output_patch") + + def set_model_attn2_output_patch(self, patch): + self.set_model_patch(patch, "attn2_output_patch") + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + patch_list[k] = patch_list[k].to(device) + + def model_dtype(self): + if hasattr(self.model, "get_dtype"): + return self.model.get_dtype() + + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): + p = set() + for k in patches: + if k in self.model_keys: + p.add(k) + current_patches = self.patches.get(k, []) + current_patches.append((strength_patch, patches[k], strength_model)) + self.patches[k] = current_patches + + return list(p) + + def get_key_patches(self, filter_prefix=None): + model_sd = self.model_state_dict() + p = {} + for k in model_sd: + if filter_prefix is not None: + if not k.startswith(filter_prefix): + continue + if k in self.patches: + p[k] = [model_sd[k]] + self.patches[k] + else: + p[k] = (model_sd[k],) + return p + + def model_state_dict(self, filter_prefix=None): + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd + + def patch_model(self, device_to=None): + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + print("could not patch. 
key doesn't exist in model:", key) + continue + + weight = model_sd[key] + + if key not in self.backup: + self.backup[key] = weight.to(self.offload_device) + + if device_to is not None: + temp_weight = weight.float().to(device_to, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + comfy.utils.set_attr(self.model, key, out_weight) + del temp_weight + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to + + return self.model + + def calculate_weight(self, patches, weight, key): + for p in patches: + alpha = p[0] + v = p[1] + strength_model = p[2] + + if strength_model != 1.0: + weight *= strength_model + + if isinstance(v, list): + v = (self.calculate_weight(v[1:], v[0].clone(), key), ) + + if len(v) == 1: + w1 = v[0] + if alpha != 0.0: + if w1.shape != weight.shape: + print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + else: + weight += alpha * w1.type(weight.dtype).to(weight.device) + elif len(v) == 4: #lora/locon + mat1 = v[0].float().to(weight.device) + mat2 = v[1].float().to(weight.device) + if v[2] is not None: + alpha *= v[2] / mat2.shape[0] + if v[3] is not None: + #locon mid weights, hopefully the math is fine because I didn't properly test it + mat3 = v[3].float().to(weight.device) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) + try: + weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + elif len(v) == 8: #lokr + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(w1_a.float(), w1_b.float()) + else: + w1 = w1.float().to(weight.device) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) + else: + w2 = w2.float().to(weight.device) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha *= v[2] / dim + + try: + weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + else: #loha + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha *= v[2] / w1b.shape[0] + w2a = v[3] + w2b = v[4] + if v[5] is not None: #cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) + m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) + else: + m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) + m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) + + try: + weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + + return weight + + def unpatch_model(self, device_to=None): + keys = list(self.backup.keys()) + + for k in keys: + 
comfy.utils.set_attr(self.model, k, self.backup[k]) + + self.backup = {} + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to diff --git a/comfy/sd.py b/comfy/sd.py index 7462c79ef..e98dabe88 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,7 +1,5 @@ import torch import contextlib -import copy -import inspect import math from comfy import model_management @@ -21,8 +19,10 @@ from . import sd1_clip from . import sd2_clip from . import sdxl_clip +import comfy.model_patcher import comfy.lora import comfy.t2i_adapter.adapter +import comfy.supported_models_base def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -53,271 +53,6 @@ def load_clip_weights(model, sd): sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) return load_model_weights(model, sd) -class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0, current_device=None): - self.size = size - self.model = model - self.patches = {} - self.backup = {} - self.model_options = {"transformer_options":{}} - self.model_size() - self.load_device = load_device - self.offload_device = offload_device - if current_device is None: - self.current_device = self.offload_device - else: - self.current_device = current_device - - def model_size(self): - if self.size > 0: - return self.size - model_sd = self.model.state_dict() - size = 0 - for k in model_sd: - t = model_sd[k] - size += t.nelement() * t.element_size() - self.size = size - self.model_keys = set(model_sd.keys()) - return size - - def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) - n.patches = {} - for k in self.patches: - n.patches[k] = self.patches[k][:] - - n.model_options = copy.deepcopy(self.model_options) - n.model_keys = self.model_keys - return n - - def is_clone(self, other): - if hasattr(other, 'model') and self.model is other.model: - return True - return False - - def set_model_sampler_cfg_function(self, sampler_cfg_function): - if len(inspect.signature(sampler_cfg_function).parameters) == 3: - self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way - else: - self.model_options["sampler_cfg_function"] = sampler_cfg_function - - def set_model_unet_function_wrapper(self, unet_wrapper_function): - self.model_options["model_function_wrapper"] = unet_wrapper_function - - def set_model_patch(self, patch, name): - to = self.model_options["transformer_options"] - if "patches" not in to: - to["patches"] = {} - to["patches"][name] = to["patches"].get(name, []) + [patch] - - def set_model_patch_replace(self, patch, name, block_name, number): - to = self.model_options["transformer_options"] - if "patches_replace" not in to: - to["patches_replace"] = {} - if name not in to["patches_replace"]: - to["patches_replace"][name] = {} - to["patches_replace"][name][(block_name, number)] = patch - - def set_model_attn1_patch(self, patch): - self.set_model_patch(patch, "attn1_patch") - - def set_model_attn2_patch(self, patch): - self.set_model_patch(patch, "attn2_patch") - - def set_model_attn1_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn1", block_name, number) - - def set_model_attn2_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn2", block_name, number) - - def set_model_attn1_output_patch(self, patch): - 
self.set_model_patch(patch, "attn1_output_patch") - - def set_model_attn2_output_patch(self, patch): - self.set_model_patch(patch, "attn2_output_patch") - - def model_patches_to(self, device): - to = self.model_options["transformer_options"] - if "patches" in to: - patches = to["patches"] - for name in patches: - patch_list = patches[name] - for i in range(len(patch_list)): - if hasattr(patch_list[i], "to"): - patch_list[i] = patch_list[i].to(device) - if "patches_replace" in to: - patches = to["patches_replace"] - for name in patches: - patch_list = patches[name] - for k in patch_list: - if hasattr(patch_list[k], "to"): - patch_list[k] = patch_list[k].to(device) - - def model_dtype(self): - if hasattr(self.model, "get_dtype"): - return self.model.get_dtype() - - def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): - p = set() - for k in patches: - if k in self.model_keys: - p.add(k) - current_patches = self.patches.get(k, []) - current_patches.append((strength_patch, patches[k], strength_model)) - self.patches[k] = current_patches - - return list(p) - - def get_key_patches(self, filter_prefix=None): - model_sd = self.model_state_dict() - p = {} - for k in model_sd: - if filter_prefix is not None: - if not k.startswith(filter_prefix): - continue - if k in self.patches: - p[k] = [model_sd[k]] + self.patches[k] - else: - p[k] = (model_sd[k],) - return p - - def model_state_dict(self, filter_prefix=None): - sd = self.model.state_dict() - keys = list(sd.keys()) - if filter_prefix is not None: - for k in keys: - if not k.startswith(filter_prefix): - sd.pop(k) - return sd - - def patch_model(self, device_to=None): - model_sd = self.model_state_dict() - for key in self.patches: - if key not in model_sd: - print("could not patch. key doesn't exist in model:", k) - continue - - weight = model_sd[key] - - if key not in self.backup: - self.backup[key] = weight.to(self.offload_device) - - if device_to is not None: - temp_weight = weight.float().to(device_to, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - comfy.utils.set_attr(self.model, key, out_weight) - del temp_weight - - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to - - return self.model - - def calculate_weight(self, patches, weight, key): - for p in patches: - alpha = p[0] - v = p[1] - strength_model = p[2] - - if strength_model != 1.0: - weight *= strength_model - - if isinstance(v, list): - v = (self.calculate_weight(v[1:], v[0].clone(), key), ) - - if len(v) == 1: - w1 = v[0] - if alpha != 0.0: - if w1.shape != weight.shape: - print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) - else: - weight += alpha * w1.type(weight.dtype).to(weight.device) - elif len(v) == 4: #lora/locon - mat1 = v[0].float().to(weight.device) - mat2 = v[1].float().to(weight.device) - if v[2] is not None: - alpha *= v[2] / mat2.shape[0] - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = v[3].float().to(weight.device) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - try: - weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) - except Exception as e: - 
print("ERROR", key, e) - elif len(v) == 8: #lokr - w1 = v[0] - w2 = v[1] - w1_a = v[3] - w1_b = v[4] - w2_a = v[5] - w2_b = v[6] - t2 = v[7] - dim = None - - if w1 is None: - dim = w1_b.shape[0] - w1 = torch.mm(w1_a.float(), w1_b.float()) - else: - w1 = w1.float().to(weight.device) - - if w2 is None: - dim = w2_b.shape[0] - if t2 is None: - w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) - else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) - else: - w2 = w2.float().to(weight.device) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - if v[2] is not None and dim is not None: - alpha *= v[2] / dim - - try: - weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) - except Exception as e: - print("ERROR", key, e) - else: #loha - w1a = v[0] - w1b = v[1] - if v[2] is not None: - alpha *= v[2] / w1b.shape[0] - w2a = v[3] - w2b = v[4] - if v[5] is not None: #cp decomposition - t1 = v[5] - t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) - else: - m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) - m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) - - try: - weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) - except Exception as e: - print("ERROR", key, e) - - return weight - - def unpatch_model(self, device_to=None): - keys = list(self.backup.keys()) - - for k in keys: - comfy.utils.set_attr(self.model, k, self.backup[k]) - - self.backup = {} - - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to - def load_lora_for_models(model, clip, lora, strength_model, strength_clip): key_map = comfy.lora.model_lora_keys_unet(model.model) @@ -346,7 +81,7 @@ class CLIP: load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() - params['device'] = load_device + params['device'] = offload_device if model_management.should_use_fp16(load_device, prioritize_performance=False): params['dtype'] = torch.float16 else: @@ -355,7 +90,7 @@ class CLIP: self.cond_stage_model = clip(**(params)) self.tokenizer = tokenizer(embedding_directory=embedding_directory) - self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.layer_idx = None def clone(self): @@ -573,7 +308,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): #TODO: this function is a mess and should be removed eventually @@ -614,10 +349,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, 
output_cl class EmptyClass: pass - model_config = EmptyClass() - model_config.unet_config = unet_config + model_config = comfy.supported_models_base.BASE({}) + from . import latent_formats model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) + model_config.unet_config = unet_config if config['model']["target"].endswith("LatentInpaintDiffusion"): model = model_base.SDInpaint(model_config, model_type=model_type) @@ -653,7 +389,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl w.cond_stage_model = clip.cond_stage_model load_clip_weights(w, state_dict) - return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) + return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): sd = comfy.utils.load_torch_file(ckpt_path) @@ -705,7 +441,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if len(left_over) > 0: print("left over keys:", left_over) - model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) @@ -735,7 +471,7 @@ def load_unet(unet_path): #load unet in diffusers format model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") - return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) + return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) def save_checkpoint(output_path, model, clip, vae, metadata=None): model_management.load_models_gpu([model, clip.load_model()]) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index d0088bbd5..c9cd54d0e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -1,6 +1,7 @@ import torch from . import model_base from . import utils +from . 
import latent_formats def state_dict_key_replace(state_dict, keys_to_replace): @@ -33,6 +34,8 @@ class BASE: clip_prefix = [] clip_vision_prefix = None noise_aug_config = None + beta_schedule = "linear" + latent_format = latent_formats.LatentFormat @classmethod def matches(s, unet_config): diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index d7c3f132f..94d453f2c 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -3,7 +3,7 @@ import math import torch import torch.nn.functional as F - +import comfy.model_management def get_canny_nms_kernel(device=None, dtype=None): """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" @@ -290,8 +290,8 @@ class Canny: CATEGORY = "image/preprocessors" def detect_edge(self, image, low_threshold, high_threshold): - output = canny(image.movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].repeat(1, 3, 1, 1).movedim(1, -1) + output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1) return (img_out,) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 5adb468ac..43f623a62 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -125,6 +125,27 @@ class ImageToMask: mask = image[0, :, :, channels.index(channel)] return (mask,) +class ImageColorToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, color): + temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int) + temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2] + mask = torch.where(temp == color, 255, 0).float() + return (mask,) + class SolidMask: @classmethod def INPUT_TYPES(cls): @@ -315,6 +336,7 @@ NODE_CLASS_MAPPINGS = { "ImageCompositeMasked": ImageCompositeMasked, "MaskToImage": MaskToImage, "ImageToMask": ImageToMask, + "ImageColorToMask": ImageColorToMask, "SolidMask": SolidMask, "InvertMask": InvertMask, "CropMask": CropMask, diff --git a/nodes.py b/nodes.py index 233bc8d40..5e755f149 100644 --- a/nodes.py +++ b/nodes.py @@ -244,14 +244,16 @@ class VAEDecode: class VAEDecodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} + return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), + "tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64}) + }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" CATEGORY = "_for_testing" - def decode(self, vae, samples): - return (vae.decode_tiled(samples["samples"]), ) + def decode(self, vae, samples, tile_size): + return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), ) class VAEEncode: @classmethod @@ -280,15 +282,17 @@ class VAEEncode: class VAEEncodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} + return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), + "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64}) + }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "_for_testing" - def encode(self, vae, pixels): + def encode(self, vae, pixels, tile_size): pixels = 
VAEEncode.vae_encode_crop_pixels(pixels) - t = vae.encode_tiled(pixels[:,:,:,:3]) + t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) return ({"samples":t}, ) class VAEEncodeForInpaint: @@ -471,7 +475,7 @@ class DiffusersLoader: model_path = path break - return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: diff --git a/server.py b/server.py index d1295342b..57d5a65df 100644 --- a/server.py +++ b/server.py @@ -1,6 +1,8 @@ import os import sys import asyncio +import traceback + import nodes import folder_paths import execution @@ -10,6 +12,7 @@ import json import glob import struct from PIL import Image, ImageOps +from PIL.PngImagePlugin import PngInfo from io import BytesIO try: @@ -79,7 +82,7 @@ class PromptServer(): if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) - self.app = web.Application(client_max_size=20971520, middlewares=middlewares) + self.app = web.Application(client_max_size=104857600, middlewares=middlewares) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") @@ -88,6 +91,8 @@ class PromptServer(): self.last_node_id = None self.client_id = None + self.on_prompt_handlers = [] + @routes.get('/ws') async def websocket_handler(request): ws = web.WebSocketResponse() @@ -122,7 +127,7 @@ class PromptServer(): @routes.get("/embeddings") def get_embeddings(self): embeddings = folder_paths.get_filename_list("embeddings") - return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings))) + return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) @routes.get("/extensions") async def get_extensions(request): @@ -229,13 +234,17 @@ class PromptServer(): if os.path.isfile(file): with Image.open(file) as original_pil: + metadata = PngInfo() + if hasattr(original_pil,'text'): + for key in original_pil.text: + metadata.add_text(key, original_pil.text[key]) original_pil = original_pil.convert('RGBA') mask_pil = Image.open(image.file).convert('RGBA') # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath, compress_level=4) + original_pil.save(filepath, compress_level=4, pnginfo=metadata) return image_upload(post, image_save_function) @@ -438,6 +447,7 @@ class PromptServer(): resp_code = 200 out_string = "" json_data = await request.json() + json_data = self.trigger_on_prompt(json_data) if "number" in json_data: number = float(json_data['number']) @@ -606,3 +616,15 @@ class PromptServer(): if call_on_start is not None: call_on_start(address, port) + def add_on_prompt_handler(self, handler): + self.on_prompt_handlers.append(handler) + + def trigger_on_prompt(self, json_data): + for handler in self.on_prompt_handlers: + try: + json_data = handler(json_data) + except Exception as e: + print(f"[ERROR] An error occurred during the on_prompt_handler processing") + traceback.print_exc() + + return json_data diff --git a/web/extensions/core/groupOptions.js b/web/extensions/core/groupOptions.js new file mode 100644 index 000000000..1d935e90a --- /dev/null +++ b/web/extensions/core/groupOptions.js @@ -0,0 +1,167 @@ +import 
{app} from "../../scripts/app.js"; + +function setNodeMode(node, mode) { + node.mode = mode; + node.graph.change(); +} + +app.registerExtension({ + name: "Comfy.GroupOptions", + setup() { + const orig = LGraphCanvas.prototype.getCanvasMenuOptions; + // graph_mouse + LGraphCanvas.prototype.getCanvasMenuOptions = function () { + const options = orig.apply(this, arguments); + const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]); + if (!group) { + return options; + } + + // Group nodes aren't recomputed until the group is moved, this ensures the nodes are up-to-date + group.recomputeInsideNodes(); + const nodesInGroup = group._nodes; + + // No nodes in group, return default options + if (nodesInGroup.length === 0) { + return options; + } else { + // Add a separator between the default options and the group options + options.push(null); + } + + // Check if all nodes are the same mode + let allNodesAreSameMode = true; + for (let i = 1; i < nodesInGroup.length; i++) { + if (nodesInGroup[i].mode !== nodesInGroup[0].mode) { + allNodesAreSameMode = false; + break; + } + } + + // Modes + // 0: Always + // 1: On Event + // 2: Never + // 3: On Trigger + // 4: Bypass + // If all nodes are the same mode, add a menu option to change the mode + if (allNodesAreSameMode) { + const mode = nodesInGroup[0].mode; + switch (mode) { + case 0: + // All nodes are always, option to disable, and bypass + options.push({ + content: "Set Group Nodes to Never", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 2); + } + } + }); + options.push({ + content: "Bypass Group Nodes", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 4); + } + } + }); + break; + case 2: + // All nodes are never, option to enable, and bypass + options.push({ + content: "Set Group Nodes to Always", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 0); + } + } + }); + options.push({ + content: "Bypass Group Nodes", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 4); + } + } + }); + break; + case 4: + // All nodes are bypass, option to enable, and disable + options.push({ + content: "Set Group Nodes to Always", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 0); + } + } + }); + options.push({ + content: "Set Group Nodes to Never", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 2); + } + } + }); + break; + default: + // All nodes are On Trigger or On Event(Or other?), option to disable, set to always, or bypass + options.push({ + content: "Set Group Nodes to Always", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 0); + } + } + }); + options.push({ + content: "Set Group Nodes to Never", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 2); + } + } + }); + options.push({ + content: "Bypass Group Nodes", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 4); + } + } + }); + break; + } + } else { + // Nodes are not all the same mode, add a menu option to change the mode to always, never, or bypass + options.push({ + content: "Set Group Nodes to Always", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 0); + } + } + }); + options.push({ + content: "Set Group Nodes to Never", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 2); + } + } + }); + options.push({ + content: "Bypass Group Nodes", + callback: () => { + for (const node 
of nodesInGroup) { + setNodeMode(node, 4); + } + } + }); + } + + return options + } + } +}); diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 356c71ac2..4bb2f0d99 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -6233,11 +6233,17 @@ LGraphNode.prototype.executeAction = function(action) ,posAdd:[!mClikSlot_isOut?-30:30, -alphaPosY*130] //-alphaPosY*30] ,posSizeFix:[!mClikSlot_isOut?-1:0, 0] //-alphaPosY*2*/ }); - + skip_action = true; } } } } + + if (!skip_action && this.allow_dragcanvas) { + //console.log("pointerevents: dragging_canvas start from middle button"); + this.dragging_canvas = true; + } + } else if (e.which == 3 || this.pointer_is_double) { diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index adf5f26fa..5a4644b13 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -299,11 +299,17 @@ export const ComfyWidgets = { const defaultVal = inputData[1].default || ""; const multiline = !!inputData[1].multiline; + let res; if (multiline) { - return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app); + res = addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app); } else { - return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) }; + res = { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) }; } + + if(inputData[1].dynamicPrompts != undefined) + res.widget.dynamicPrompts = inputData[1].dynamicPrompts; + + return res; }, COMBO(node, inputName, inputData) { const type = inputData[0];
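
Note on the VAE precision changes in comfy/cli_args.py and comfy/model_management.py: the default VAE dtype is now picked automatically (bf16 on NVIDIA GPUs with torch >= 2 and bf16 support, fp32 otherwise) and the explicit --fp16-vae / --bf16-vae / --fp32-vae flags override it. A rough standalone restatement of that precedence, not the verbatim implementation:

    import torch

    def pick_vae_dtype(args, nvidia_gpu, torch_major):
        dtype = torch.float32
        try:
            if nvidia_gpu and torch_major >= 2 and torch.cuda.is_bf16_supported():
                dtype = torch.bfloat16      # auto-upgrade on supported NVIDIA GPUs
        except Exception:
            pass                            # no usable CUDA device, keep fp32
        # Explicit command line flags always win over the auto-detected default.
        if args.fp16_vae:
            dtype = torch.float16
        elif args.bf16_vae:
            dtype = torch.bfloat16
        elif args.fp32_vae:
            dtype = torch.float32
        return dtype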
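
The rewritten comfy/diffusers_load.py loads the unet, text encoder(s) and VAE directly from a diffusers folder layout via comfy.sd instead of converting through a yaml config. A hedged usage sketch of the new signature (the folder path is a placeholder):

    import comfy.diffusers_load

    # model_path must contain unet/, vae/ and text_encoder/
    # (plus text_encoder_2/ for two-encoder models).
    unet, clip, vae = comfy.diffusers_load.load_diffusers(
        "/path/to/diffusers_model",
        output_vae=True,
        output_clip=True,
        embedding_directory=None,
    )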
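
The ModelPatcher class is moved unchanged from comfy/sd.py into the new comfy/model_patcher.py module, and callers (sd.py, controlnet.py, clip_vision.py) now import it from there. A minimal usage sketch of the relocated API, assuming some_model is a torch.nn.Module and patches is a dict keyed by state_dict names (both placeholders):

    import comfy.model_patcher
    import comfy.model_management

    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()

    # Wrap the model; weights start on the offload device.
    patcher = comfy.model_patcher.ModelPatcher(some_model, load_device=load_device,
                                               offload_device=offload_device)

    # Register weight patches (e.g. LoRA deltas) keyed by state_dict names.
    patcher.add_patches(patches, strength_patch=1.0, strength_model=1.0)

    # Apply the patches and move the model to the load device.
    patcher.patch_model(device_to=load_device)
    # ... run inference ...
    # Restore the backed-up weights and offload again.
    patcher.unpatch_model(device_to=offload_device)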
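
For the new ImageColorToMask node, the color input is a single integer with channels packed as 0xRRGGBB, matching the bit shifts in image_to_mask. An illustrative packing example:

    # Pure red as a packed 0xRRGGBB integer: 0xFF0000 == 16711680
    r, g, b = 255, 0, 0
    color = (r << 16) | (g << 8) | b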
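
The server.py change adds a small hook API: handlers registered with add_on_prompt_handler receive the queued prompt JSON and must return it (possibly modified); exceptions inside a handler are caught and logged by trigger_on_prompt. A minimal sketch, assuming prompt_server is the running PromptServer instance (placeholder name):

    def tag_prompt(json_data):
        # Runs on every POST /prompt before the prompt is validated and queued.
        json_data.setdefault("extra_data", {})["tagged_by"] = "my_extension"
        return json_data

    prompt_server.add_on_prompt_handler(tag_prompt)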