Merge branch 'comfyanonymous:master' into feature/toggle_migration

Commit af2849bc36 by Dr.Lt.Data, 2023-08-31 17:04:14 +09:00, committed by GitHub
(no known key found for this signature in database; GPG Key ID: 4AEE18F83AFDEB23)
18 changed files with 665 additions and 421 deletions

View File

@@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
-fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
+fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

View File

@@ -2,14 +2,27 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
from .utils import load_torch_file, transformers_convert
import os
import torch
+import contextlib
import comfy.ops
+import comfy.model_patcher
+import comfy.model_management

class ClipVisionModel():
    def __init__(self, json_config):
        config = CLIPVisionConfig.from_json_file(json_config)
-        with comfy.ops.use_comfy_ops():
+        self.load_device = comfy.model_management.text_encoder_device()
+        offload_device = comfy.model_management.text_encoder_offload_device()
+        self.dtype = torch.float32
+        if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+            self.dtype = torch.float16
+        with comfy.ops.use_comfy_ops(offload_device, self.dtype):
            with modeling_utils.no_init_weights():
                self.model = CLIPVisionModelWithProjection(config)
+        self.model.to(self.dtype)
+        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.processor = CLIPImageProcessor(crop_size=224,
                                            do_center_crop=True,
                                            do_convert_rgb=True,
@@ -27,7 +40,21 @@ class ClipVisionModel():
        img = torch.clip((255. * image), 0, 255).round().int()
        img = list(map(lambda a: a, img))
        inputs = self.processor(images=img, return_tensors="pt")
-        outputs = self.model(**inputs)
+        comfy.model_management.load_model_gpu(self.patcher)
+        pixel_values = inputs['pixel_values'].to(self.load_device)
+        if self.dtype != torch.float32:
+            precision_scope = torch.autocast
+        else:
+            precision_scope = lambda a, b: contextlib.nullcontext(a)
+
+        with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
+            outputs = self.model(pixel_values=pixel_values)
+
+        for k in outputs:
+            t = outputs[k]
+            if t is not None:
+                outputs[k] = t.cpu()
        return outputs

def convert_to_transformers(sd, prefix):
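The encode path above picks between torch.autocast and a null context so one with-block serves both fp16 and fp32 weights. A simplified, self-contained sketch of the same pattern (the fixed "cuda" device and fp16 autocast dtype are assumptions here, not taken from the diff):

    import contextlib
    import torch

    def forward_with_optional_autocast(model, x, use_fp16):
        # Choose the context manager once, then write the forward call a single time.
        if use_fp16:
            ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
        else:
            ctx = contextlib.nullcontext()
        with ctx:
            return model(x)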

View File

@@ -1,9 +1,10 @@
import torch
import math
+import os
import comfy.utils
-import comfy.sd
import comfy.model_management
import comfy.model_detection
+import comfy.model_patcher

import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@@ -129,7 +130,7 @@ class ControlNet(ControlBase):
    def __init__(self, control_model, global_average_pooling=False, device=None):
        super().__init__(device)
        self.control_model = control_model
-        self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
+        self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
        self.global_average_pooling = global_average_pooling

    def get_control(self, x_noisy, t, cond, batched_number):
@@ -257,12 +258,7 @@ class ControlLora(ControlNet):
        cm = self.control_model.state_dict()

        for k in sd:
-            weight = sd[k]
-            if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-                key_split = k.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
-                op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
-                weight = op._hf_hook.weights_map[key_split[-1]]
+            weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
            try:
                comfy.utils.set_attr(self.control_model, k, weight)
            except:
@@ -391,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
            control_model = control_model.half()

    global_average_pooling = False
-    if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
+    filename = os.path.splitext(ckpt_path)[0]
+    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
@@ -468,7 +465,7 @@ def load_t2i_adapter(t2i_data):
        if len(down_opts) > 0:
            use_conv = True
        xl = False
-        if cin == 256:
+        if cin == 256 or cin == 768:
            xl = True
        model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    else:
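The shuffle detection above now strips the extension first, so .pth and .safetensors checkpoints go through the same suffix check. A small illustration with a hypothetical file name:

    import os

    ckpt_path = "controlnet/control_v11e_sd15_shuffle_fp16.safetensors"  # hypothetical path
    filename = os.path.splitext(ckpt_path)[0]
    print(filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"))  # True, regardless of extension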

View File

@@ -1,87 +1,36 @@
import json
import os
-import yaml
-
-import folder_paths
-from comfy.sd import load_checkpoint
-import os.path as osp
-import re
-import torch
-from safetensors.torch import load_file, save_file
-from . import diffusers_convert
-
-def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
-    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
-    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
-
-    # magic
-    v2 = diffusers_unet_conf["sample_size"] == 96
-    if 'prediction_type' in diffusers_scheduler_conf:
-        v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
-
-    if v2:
-        if v_pred:
-            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
-        else:
-            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
-    else:
-        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
-
-    with open(config_path, 'r') as stream:
-        config = yaml.safe_load(stream)
-
-    model_config_params = config['model']['params']
-    clip_config = model_config_params['cond_stage_config']
-    scale_factor = model_config_params['scale_factor']
-    vae_config = model_config_params['first_stage_config']
-    vae_config['scale_factor'] = scale_factor
-    model_config_params["unet_config"]["params"]["use_fp16"] = fp16
-
-    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
-    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
-    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
-
-    # Load models from safetensors if it exists, if it doesn't pytorch
-    if osp.exists(unet_path):
-        unet_state_dict = load_file(unet_path, device="cpu")
-    else:
-        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
-        unet_state_dict = torch.load(unet_path, map_location="cpu")
-
-    if osp.exists(vae_path):
-        vae_state_dict = load_file(vae_path, device="cpu")
-    else:
-        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
-        vae_state_dict = torch.load(vae_path, map_location="cpu")
-
-    if osp.exists(text_enc_path):
-        text_enc_dict = load_file(text_enc_path, device="cpu")
-    else:
-        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
-        text_enc_dict = torch.load(text_enc_path, map_location="cpu")
-
-    # Convert the UNet model
-    unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
-    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
-
-    # Convert the VAE model
-    vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
-    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
-
-    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
-    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
-
-    if is_v20_model:
-        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
-        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
-        text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
-        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
-    else:
-        text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
-        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
-
-    # Put together new checkpoint
-    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-
-    return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
+import comfy.sd
+
+def first_file(path, filenames):
+    for f in filenames:
+        p = os.path.join(path, f)
+        if os.path.exists(p):
+            return p
+    return None
+
+def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
+    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
+    unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
+    vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
+
+    text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
+    text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
+    text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
+
+    text_encoder_paths = [text_encoder1_path]
+    if text_encoder2_path is not None:
+        text_encoder_paths.append(text_encoder2_path)
+
+    unet = comfy.sd.load_unet(unet_path)
+
+    clip = None
+    if output_clip:
+        clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
+
+    vae = None
+    if output_vae:
+        vae = comfy.sd.VAE(ckpt_path=vae_path)
+
+    return (unet, clip, vae)
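With the rewrite above, callers no longer pass fp16 or convert state dicts themselves; the loader picks the first weight file that exists (fp16 safetensors preferred) and returns ComfyUI objects directly. A minimal usage sketch with a hypothetical model folder:

    import comfy.diffusers_load

    unet, clip, vae = comfy.diffusers_load.load_diffusers(
        "/models/diffusers/my-sdxl-checkpoint",  # hypothetical path containing unet/, vae/, text_encoder*/
        output_vae=True, output_clip=True, embedding_directory=None)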

View File

@@ -56,7 +56,18 @@ class Upsample(nn.Module):
                                        padding=1)

    def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        try:
+            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        except: #operation not implemented for bf16
+            b, c, h, w = x.shape
+            out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
+            split = 8
+            l = out.shape[1] // split
+            for i in range(0, out.shape[1], l):
+                out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
+            del x
+            x = out
        if self.with_conv:
            x = self.conv(x)
        return x
@@ -74,11 +85,10 @@ class Downsample(nn.Module):
                                        stride=2,
                                        padding=0)

-    def forward(self, x, already_padded=False):
+    def forward(self, x):
        if self.with_conv:
-            if not already_padded:
-                pad = (0,1,0,1)
-                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@@ -275,25 +285,17 @@ class MemoryEfficientAttnBlock(nn.Module):
        # compute attention
        B, C, H, W = q.shape
-        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
        q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
            (q, k, v),
        )
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
-
-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
-        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        try:
+            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+            out = out.transpose(1, 2).reshape(B, C, H, W)
+        except NotImplementedError as e:
+            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
        out = self.proj_out(out)
        return x+out
@@ -603,9 +605,6 @@ class Encoder(nn.Module):
    def forward(self, x):
        # timestep embedding
        temb = None
-        pad = (0,1,0,1)
-        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
-        already_padded = True
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
@@ -614,8 +613,7 @@ class Encoder(nn.Module):
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions-1:
-                h = self.down[i_level].downsample(h, already_padded)
-                already_padded = False
+                h = self.down[i_level].downsample(h)
        # middle
        h = self.mid.block_1(h, temb)

View File

@@ -118,6 +118,19 @@ def load_lora(lora, to_load):
        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)

+        w_norm_name = "{}.w_norm".format(x)
+        b_norm_name = "{}.b_norm".format(x)
+        w_norm = lora.get(w_norm_name, None)
+        b_norm = lora.get(b_norm_name, None)
+
+        if w_norm is not None:
+            loaded_keys.add(w_norm_name)
+            patch_dict[to_load[x]] = (w_norm,)
+            if b_norm is not None:
+                loaded_keys.add(b_norm_name)
+                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
+
    for x in lora.keys():
        if x not in loaded_keys:
            print("lora key not loaded", x)

View File

@@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
+import comfy.model_management
import numpy as np
from enum import Enum
from . import utils
@@ -18,8 +19,9 @@ class BaseModel(torch.nn.Module):
        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
-        self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config, device=device)
+        self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+        if not unet_config.get("disable_unet_model_creation", False):
+            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
@@ -93,7 +95,11 @@ class BaseModel(torch.nn.Module):
    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_state_dict = self.diffusion_model.state_dict()
+        unet_sd = self.diffusion_model.state_dict()
+        unet_state_dict = {}
+        for k in unet_sd:
+            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
+
        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:

View File

@@ -1,6 +1,7 @@
import psutil
from enum import Enum
from comfy.cli_args import args
+import comfy.utils
import torch
import sys
@@ -147,15 +148,27 @@ def is_nvidia():
            return True

ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+VAE_DTYPE = torch.float32

-if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-    try:
-        if is_nvidia():
-            torch_version = torch.version.__version__
-            if int(torch_version[0]) >= 2:
-                ENABLE_PYTORCH_ATTENTION = True
-    except:
-        pass
+try:
+    if is_nvidia():
+        torch_version = torch.version.__version__
+        if int(torch_version[0]) >= 2:
+            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+                ENABLE_PYTORCH_ATTENTION = True
+            if torch.cuda.is_bf16_supported():
+                VAE_DTYPE = torch.bfloat16
+except:
+    pass
+
+if args.fp16_vae:
+    VAE_DTYPE = torch.float16
+elif args.bf16_vae:
+    VAE_DTYPE = torch.bfloat16
+elif args.fp32_vae:
+    VAE_DTYPE = torch.float32

if ENABLE_PYTORCH_ATTENTION:
    torch.backends.cuda.enable_math_sdp(True)
@@ -227,6 +240,7 @@ try:
except:
    print("Could not pick default device.")

+print("VAE dtype:", VAE_DTYPE)

current_loaded_models = []
@@ -447,12 +461,8 @@ def vae_offload_device():
        return torch.device("cpu")

def vae_dtype():
-    if args.fp16_vae:
-        return torch.float16
-    elif args.bf16_vae:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    global VAE_DTYPE
+    return VAE_DTYPE

def get_autocast_device(dev):
    if hasattr(dev, 'type'):
@@ -637,6 +647,13 @@ def soft_empty_cache():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

+def resolve_lowvram_weight(weight, model, key):
+    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
+        key_split = key.split('.')             # I have no idea why they don't just leave the weight there instead of using the meta device.
+        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
+        weight = op._hf_hook.weights_map[key_split[-1]]
+    return weight
+
#TODO: might be cleaner to put this somewhere else
import threading
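Net effect of the model_management changes above: VAE_DTYPE defaults to bf16 on NVIDIA GPUs running PyTorch 2+ with bf16 support, and the explicit command line flags always override that default. A condensed sketch of the precedence (the helper name is illustrative, not from the diff):

    import torch

    def pick_vae_dtype(args, bf16_supported):
        vae_dtype = torch.bfloat16 if bf16_supported else torch.float32  # auto-detected default
        if args.fp16_vae:
            vae_dtype = torch.float16
        elif args.bf16_vae:
            vae_dtype = torch.bfloat16
        elif args.fp32_vae:
            vae_dtype = torch.float32
        return vae_dtype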

comfy/model_patcher.py (new file, 270 lines)

View File

@@ -0,0 +1,270 @@
import torch
import copy
import inspect
import comfy.utils
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
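ModelPatcher is now its own module so controlnet.py, clip_vision.py and sd.py can share it. A rough usage sketch with a toy module standing in for a real diffusion model (the zero-valued diff only exercises the 1-element patch path; this is illustrative, not from the diff):

    import torch
    import comfy.model_patcher

    model = torch.nn.Linear(4, 4)
    patcher = comfy.model_patcher.ModelPatcher(model, load_device=torch.device("cpu"),
                                               offload_device=torch.device("cpu"))

    # A 1-element patch tuple is treated as an additive diff on the named weight.
    patcher.add_patches({"weight": (torch.zeros(4, 4),)}, strength_patch=1.0)  # -> ["weight"]

    patcher.patch_model()    # applies the patch, keeping the original weight in patcher.backup
    patcher.unpatch_model()  # restores the original weight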

View File

@@ -1,7 +1,5 @@
import torch
import contextlib
-import copy
-import inspect
import math

from comfy import model_management
@@ -21,8 +19,10 @@ from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

+import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
+import comfy.supported_models_base

def load_model_weights(model, sd):
    m, u = model.load_state_dict(sd, strict=False)
@@ -53,271 +53,6 @@ def load_clip_weights(model, sd):
        sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
    return load_model_weights(model, sd)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", k)
continue
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    key_map = comfy.lora.model_lora_keys_unet(model.model)
@@ -346,7 +81,7 @@ class CLIP:
        load_device = model_management.text_encoder_device()
        offload_device = model_management.text_encoder_offload_device()
-        params['device'] = load_device
+        params['device'] = offload_device
        if model_management.should_use_fp16(load_device, prioritize_performance=False):
            params['dtype'] = torch.float16
        else:
@@ -355,7 +90,7 @@ class CLIP:
        self.cond_stage_model = clip(**(params))
        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
-        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
        self.layer_idx = None

    def clone(self):
@@ -573,7 +308,7 @@ def load_gligen(ckpt_path):
    model = gligen.load_gligen(data)
    if model_management.should_use_fp16():
        model = model.half()
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
    #TODO: this function is a mess and should be removed eventually
@@ -614,10 +349,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
    class EmptyClass:
        pass
-    model_config = EmptyClass()
-    model_config.unet_config = unet_config
+    model_config = comfy.supported_models_base.BASE({})

    from . import latent_formats
    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
+    model_config.unet_config = unet_config

    if config['model']["target"].endswith("LatentInpaintDiffusion"):
        model = model_base.SDInpaint(model_config, model_type=model_type)
@@ -653,7 +389,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
        w.cond_stage_model = clip.cond_stage_model
        load_clip_weights(w, state_dict)

-    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
+    return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)

def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
    sd = comfy.utils.load_torch_file(ckpt_path)
@@ -705,7 +441,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    if len(left_over) > 0:
        print("left over keys:", left_over)

-    model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
    if inital_load_device != torch.device("cpu"):
        print("loaded straight to GPU")
        model_management.load_model_gpu(model_patcher)
@@ -735,7 +471,7 @@ def load_unet(unet_path): #load unet in diffusers format
    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)

def save_checkpoint(output_path, model, clip, vae, metadata=None):
    model_management.load_models_gpu([model, clip.load_model()])

View File

@@ -1,6 +1,7 @@
import torch
from . import model_base
from . import utils
+from . import latent_formats

def state_dict_key_replace(state_dict, keys_to_replace):
@@ -33,6 +34,8 @@ class BASE:
    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
+    beta_schedule = "linear"
+    latent_format = latent_formats.LatentFormat

    @classmethod
    def matches(s, unet_config):

View File

@@ -3,7 +3,7 @@ import math

import torch
import torch.nn.functional as F
+import comfy.model_management

def get_canny_nms_kernel(device=None, dtype=None):
    """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
@@ -290,8 +290,8 @@ class Canny:
    CATEGORY = "image/preprocessors"

    def detect_edge(self, image, low_threshold, high_threshold):
-        output = canny(image.movedim(-1, 1), low_threshold, high_threshold)
-        img_out = output[1].repeat(1, 3, 1, 1).movedim(1, -1)
+        output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
+        img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
        return (img_out,)

NODE_CLASS_MAPPINGS = {

View File

@@ -125,6 +125,27 @@ class ImageToMask:
        mask = image[0, :, :, channels.index(channel)]
        return (mask,)

+class ImageColorToMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+                "required": {
+                    "image": ("IMAGE",),
+                    "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
+                }
+        }
+
+    CATEGORY = "mask"
+
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "image_to_mask"
+
+    def image_to_mask(self, image, color):
+        temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int)
+        temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2]
+        mask = torch.where(temp == color, 255, 0).float()
+        return (mask,)
+
class SolidMask:
    @classmethod
    def INPUT_TYPES(cls):
@@ -315,6 +336,7 @@ NODE_CLASS_MAPPINGS = {
    "ImageCompositeMasked": ImageCompositeMasked,
    "MaskToImage": MaskToImage,
    "ImageToMask": ImageToMask,
+    "ImageColorToMask": ImageColorToMask,
    "SolidMask": SolidMask,
    "InvertMask": InvertMask,
    "CropMask": CropMask,

View File

@@ -244,14 +244,16 @@ class VAEDecode:
class VAEDecodeTiled:
    @classmethod
    def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+        return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
+                             "tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64})
+                            }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "_for_testing"

-    def decode(self, vae, samples):
-        return (vae.decode_tiled(samples["samples"]), )
+    def decode(self, vae, samples, tile_size):
+        return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )

class VAEEncode:
    @classmethod
@@ -280,15 +282,17 @@ class VAEEncode:
class VAEEncodeTiled:
    @classmethod
    def INPUT_TYPES(s):
-        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
+        return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
+                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
+                            }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "_for_testing"

-    def encode(self, vae, pixels):
+    def encode(self, vae, pixels, tile_size):
        pixels = VAEEncode.vae_encode_crop_pixels(pixels)
-        t = vae.encode_tiled(pixels[:,:,:,:3])
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
        return ({"samples":t}, )

class VAEEncodeForInpaint:
@@ -471,7 +475,7 @@ class DiffusersLoader:
                model_path = path
                break

-        return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))

class unCLIPCheckpointLoader:
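The new tile_size input is given in pixels. VAEDecodeTiled divides it by 8 before calling the VAE, since the SD VAE works on a latent grid 8x smaller than the image, while VAEEncodeTiled passes the pixel value straight through. A trivial illustration of that mapping:

    tile_size = 512            # the node's default, in pixels
    print(tile_size // 8)      # 64: the tile_x / tile_y value handed to vae.decode_tiled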

View File

@@ -1,6 +1,8 @@
import os
import sys
import asyncio
+import traceback

import nodes
import folder_paths
import execution
@@ -10,6 +12,7 @@ import json
import glob
import struct
from PIL import Image, ImageOps
+from PIL.PngImagePlugin import PngInfo
from io import BytesIO

try:
@@ -79,7 +82,7 @@ class PromptServer():
        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))

-        self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
+        self.app = web.Application(client_max_size=104857600, middlewares=middlewares)
        self.sockets = dict()
        self.web_root = os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "web")
@@ -88,6 +91,8 @@ class PromptServer():
        self.last_node_id = None
        self.client_id = None

+        self.on_prompt_handlers = []
+
        @routes.get('/ws')
        async def websocket_handler(request):
            ws = web.WebSocketResponse()
@@ -122,7 +127,7 @@ class PromptServer():
        @routes.get("/embeddings")
        def get_embeddings(self):
            embeddings = folder_paths.get_filename_list("embeddings")
-            return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings)))
+            return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

        @routes.get("/extensions")
        async def get_extensions(request):
@@ -229,13 +234,17 @@ class PromptServer():
            if os.path.isfile(file):
                with Image.open(file) as original_pil:
+                    metadata = PngInfo()
+                    if hasattr(original_pil,'text'):
+                        for key in original_pil.text:
+                            metadata.add_text(key, original_pil.text[key])
                    original_pil = original_pil.convert('RGBA')
                    mask_pil = Image.open(image.file).convert('RGBA')

                    # alpha copy
                    new_alpha = mask_pil.getchannel('A')
                    original_pil.putalpha(new_alpha)
-                    original_pil.save(filepath, compress_level=4)
+                    original_pil.save(filepath, compress_level=4, pnginfo=metadata)

                return image_upload(post, image_save_function)
@@ -438,6 +447,7 @@ class PromptServer():
            resp_code = 200
            out_string = ""
            json_data = await request.json()
+            json_data = self.trigger_on_prompt(json_data)

            if "number" in json_data:
                number = float(json_data['number'])
@@ -606,3 +616,15 @@ class PromptServer():
        if call_on_start is not None:
            call_on_start(address, port)

+    def add_on_prompt_handler(self, handler):
+        self.on_prompt_handlers.append(handler)
+
+    def trigger_on_prompt(self, json_data):
+        for handler in self.on_prompt_handlers:
+            try:
+                json_data = handler(json_data)
+            except Exception as e:
+                print(f"[ERROR] An error occurred during the on_prompt_handler processing")
+                traceback.print_exc()
+
+        return json_data
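A hedged sketch of how an extension could use the new hook, assuming the usual PromptServer.instance singleton that custom nodes access; the handler and the extra_data key are made-up names:

    import server

    def tag_prompt(json_data):
        # runs on every /prompt POST before the workflow is queued
        json_data.setdefault("extra_data", {})["seen_by_my_extension"] = True
        return json_data

    server.PromptServer.instance.add_on_prompt_handler(tag_prompt)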

View File

@@ -0,0 +1,167 @@
import {app} from "../../scripts/app.js";
function setNodeMode(node, mode) {
node.mode = mode;
node.graph.change();
}
app.registerExtension({
name: "Comfy.GroupOptions",
setup() {
const orig = LGraphCanvas.prototype.getCanvasMenuOptions;
// graph_mouse
LGraphCanvas.prototype.getCanvasMenuOptions = function () {
const options = orig.apply(this, arguments);
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
if (!group) {
return options;
}
// Group nodes aren't recomputed until the group is moved, this ensures the nodes are up-to-date
group.recomputeInsideNodes();
const nodesInGroup = group._nodes;
// No nodes in group, return default options
if (nodesInGroup.length === 0) {
return options;
} else {
// Add a separator between the default options and the group options
options.push(null);
}
// Check if all nodes are the same mode
let allNodesAreSameMode = true;
for (let i = 1; i < nodesInGroup.length; i++) {
if (nodesInGroup[i].mode !== nodesInGroup[0].mode) {
allNodesAreSameMode = false;
break;
}
}
// Modes
// 0: Always
// 1: On Event
// 2: Never
// 3: On Trigger
// 4: Bypass
// If all nodes are the same mode, add a menu option to change the mode
if (allNodesAreSameMode) {
const mode = nodesInGroup[0].mode;
switch (mode) {
case 0:
// All nodes are always, option to disable, and bypass
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
case 2:
// All nodes are never, option to enable, and bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
case 4:
// All nodes are bypass, option to enable, and disable
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
break;
default:
// All nodes are On Trigger or On Event(Or other?), option to disable, set to always, or bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
}
} else {
// Nodes are not all the same mode, add a menu option to change the mode to always, never, or bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
}
return options
}
}
});

View File

@@ -6233,11 +6233,17 @@ LGraphNode.prototype.executeAction = function(action)
                            ,posAdd:[!mClikSlot_isOut?-30:30, -alphaPosY*130] //-alphaPosY*30]
                            ,posSizeFix:[!mClikSlot_isOut?-1:0, 0] //-alphaPosY*2*/
                        });
+                        skip_action = true;
                    }
                }
            }
        }

+        if (!skip_action && this.allow_dragcanvas) {
+            //console.log("pointerevents: dragging_canvas start from middle button");
+            this.dragging_canvas = true;
+        }
+
    } else if (e.which == 3 || this.pointer_is_double) {

View File

@@ -299,11 +299,17 @@ export const ComfyWidgets = {
        const defaultVal = inputData[1].default || "";
        const multiline = !!inputData[1].multiline;

+        let res;
        if (multiline) {
-            return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
+            res = addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
        } else {
-            return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
+            res = { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
        }
+
+        if(inputData[1].dynamicPrompts != undefined)
+            res.widget.dynamicPrompts = inputData[1].dynamicPrompts;
+
+        return res;
    },
    COMBO(node, inputName, inputData) {
        const type = inputData[0];