mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 05:22:34 +08:00
Merge branch 'comfyanonymous:master' into fix/secure-combo
This commit is contained in:
commit
7e1f20b9f4
@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
fpvae_group = parser.add_mutually_exclusive_group()
|
||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
|
||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
||||
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
|
||||
@ -2,14 +2,27 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
|
||||
from .utils import load_torch_file, transformers_convert
|
||||
import os
|
||||
import torch
|
||||
import contextlib
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
config = CLIPVisionConfig.from_json_file(json_config)
|
||||
with comfy.ops.use_comfy_ops():
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
|
||||
self.dtype = torch.float16
|
||||
|
||||
with comfy.ops.use_comfy_ops(offload_device, self.dtype):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.model = CLIPVisionModelWithProjection(config)
|
||||
self.model.to(self.dtype)
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.processor = CLIPImageProcessor(crop_size=224,
|
||||
do_center_crop=True,
|
||||
do_convert_rgb=True,
|
||||
@ -27,7 +40,21 @@ class ClipVisionModel():
|
||||
img = torch.clip((255. * image), 0, 255).round().int()
|
||||
img = list(map(lambda a: a, img))
|
||||
inputs = self.processor(images=img, return_tensors="pt")
|
||||
outputs = self.model(**inputs)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = inputs['pixel_values'].to(self.load_device)
|
||||
|
||||
if self.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
|
||||
outputs = self.model(pixel_values=pixel_values)
|
||||
|
||||
for k in outputs:
|
||||
t = outputs[k]
|
||||
if t is not None:
|
||||
outputs[k] = t.cpu()
|
||||
return outputs
|
||||
|
||||
def convert_to_transformers(sd, prefix):
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
import comfy.model_patcher
|
||||
|
||||
import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
@ -129,7 +130,7 @@ class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
@ -257,12 +258,7 @@ class ControlLora(ControlNet):
|
||||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
weight = sd[k]
|
||||
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||
op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
|
||||
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||
|
||||
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
|
||||
try:
|
||||
comfy.utils.set_attr(self.control_model, k, weight)
|
||||
except:
|
||||
@ -391,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
|
||||
control_model = control_model.half()
|
||||
|
||||
global_average_pooling = False
|
||||
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||
@ -468,7 +465,7 @@ def load_t2i_adapter(t2i_data):
|
||||
if len(down_opts) > 0:
|
||||
use_conv = True
|
||||
xl = False
|
||||
if cin == 256:
|
||||
if cin == 256 or cin == 768:
|
||||
xl = True
|
||||
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
||||
else:
|
||||
|
||||
@ -1,87 +1,36 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
|
||||
import folder_paths
|
||||
from comfy.sd import load_checkpoint
|
||||
import os.path as osp
|
||||
import re
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from . import diffusers_convert
|
||||
import comfy.sd
|
||||
|
||||
def first_file(path, filenames):
|
||||
for f in filenames:
|
||||
p = os.path.join(path, f)
|
||||
if os.path.exists(p):
|
||||
return p
|
||||
return None
|
||||
|
||||
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
||||
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
|
||||
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
|
||||
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
|
||||
vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
|
||||
|
||||
# magic
|
||||
v2 = diffusers_unet_conf["sample_size"] == 96
|
||||
if 'prediction_type' in diffusers_scheduler_conf:
|
||||
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
|
||||
text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
|
||||
text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
|
||||
text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
|
||||
|
||||
if v2:
|
||||
if v_pred:
|
||||
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
|
||||
else:
|
||||
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
|
||||
else:
|
||||
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
|
||||
text_encoder_paths = [text_encoder1_path]
|
||||
if text_encoder2_path is not None:
|
||||
text_encoder_paths.append(text_encoder2_path)
|
||||
|
||||
with open(config_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
unet = comfy.sd.load_unet(unet_path)
|
||||
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
scale_factor = model_config_params['scale_factor']
|
||||
vae_config = model_config_params['first_stage_config']
|
||||
vae_config['scale_factor'] = scale_factor
|
||||
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
|
||||
clip = None
|
||||
if output_clip:
|
||||
clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
|
||||
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||
vae = None
|
||||
if output_vae:
|
||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
||||
|
||||
# Load models from safetensors if it exists, if it doesn't pytorch
|
||||
if osp.exists(unet_path):
|
||||
unet_state_dict = load_file(unet_path, device="cpu")
|
||||
else:
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
|
||||
if osp.exists(vae_path):
|
||||
vae_state_dict = load_file(vae_path, device="cpu")
|
||||
else:
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
|
||||
if osp.exists(text_enc_path):
|
||||
text_enc_dict = load_file(text_enc_path, device="cpu")
|
||||
else:
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
|
||||
# Convert the VAE model
|
||||
vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||
|
||||
if is_v20_model:
|
||||
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
||||
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
||||
else:
|
||||
text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Put together new checkpoint
|
||||
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
|
||||
return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
|
||||
return (unet, clip, vae)
|
||||
|
||||
@ -56,7 +56,18 @@ class Upsample(nn.Module):
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
try:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
except: #operation not implemented for bf16
|
||||
b, c, h, w = x.shape
|
||||
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
|
||||
del x
|
||||
x = out
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
@ -74,11 +85,10 @@ class Downsample(nn.Module):
|
||||
stride=2,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, already_padded=False):
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
if not already_padded:
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
@ -275,25 +285,17 @@ class MemoryEfficientAttnBlock(nn.Module):
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(B, 1, out.shape[1], C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B, out.shape[1], C)
|
||||
)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
||||
try:
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||
except NotImplementedError as e:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
||||
@ -603,9 +605,6 @@ class Encoder(nn.Module):
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
already_padded = True
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_resolutions):
|
||||
@ -614,8 +613,7 @@ class Encoder(nn.Module):
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h = self.down[i_level].downsample(h, already_padded)
|
||||
already_padded = False
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
|
||||
@ -118,6 +118,19 @@ def load_lora(lora, to_load):
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
|
||||
|
||||
|
||||
w_norm_name = "{}.w_norm".format(x)
|
||||
b_norm_name = "{}.b_norm".format(x)
|
||||
w_norm = lora.get(w_norm_name, None)
|
||||
b_norm = lora.get(b_norm_name, None)
|
||||
|
||||
if w_norm is not None:
|
||||
loaded_keys.add(w_norm_name)
|
||||
patch_dict[to_load[x]] = (w_norm,)
|
||||
if b_norm is not None:
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
print("lora key not loaded", x)
|
||||
|
||||
@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
@ -18,8 +19,9 @@ class BaseModel(torch.nn.Module):
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
self.model_type = model_type
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
@ -93,7 +95,11 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_sd = self.diffusion_model.state_dict()
|
||||
unet_state_dict = {}
|
||||
for k in unet_sd:
|
||||
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
||||
if self.get_dtype() == torch.float16:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import torch
|
||||
import sys
|
||||
|
||||
@ -147,15 +148,27 @@ def is_nvidia():
|
||||
return True
|
||||
|
||||
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
try:
|
||||
if is_nvidia():
|
||||
torch_version = torch.version.__version__
|
||||
if int(torch_version[0]) >= 2:
|
||||
|
||||
try:
|
||||
if is_nvidia():
|
||||
torch_version = torch.version.__version__
|
||||
if int(torch_version[0]) >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
pass
|
||||
if torch.cuda.is_bf16_supported():
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
except:
|
||||
pass
|
||||
|
||||
if args.fp16_vae:
|
||||
VAE_DTYPE = torch.float16
|
||||
elif args.bf16_vae:
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
elif args.fp32_vae:
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
@ -227,6 +240,7 @@ try:
|
||||
except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
print("VAE dtype:", VAE_DTYPE)
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
@ -447,12 +461,8 @@ def vae_offload_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
def vae_dtype():
|
||||
if args.fp16_vae:
|
||||
return torch.float16
|
||||
elif args.bf16_vae:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
global VAE_DTYPE
|
||||
return VAE_DTYPE
|
||||
|
||||
def get_autocast_device(dev):
|
||||
if hasattr(dev, 'type'):
|
||||
@ -637,6 +647,13 @@ def soft_empty_cache():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key):
|
||||
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
|
||||
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||
return weight
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
|
||||
270
comfy/model_patcher.py
Normal file
270
comfy/model_patcher.py
Normal file
@ -0,0 +1,270 @@
|
||||
import torch
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
import comfy.utils
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
if current_device is None:
|
||||
self.current_device = self.offload_device
|
||||
else:
|
||||
self.current_device = current_device
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
model_sd = self.model.state_dict()
|
||||
size = 0
|
||||
for k in model_sd:
|
||||
t = model_sd[k]
|
||||
size += t.nelement() * t.element_size()
|
||||
self.size = size
|
||||
self.model_keys = set(model_sd.keys())
|
||||
return size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
return n
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
to["patches"] = {}
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_patch_replace(self, patch, name, block_name, number):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
to["patches_replace"][name][(block_name, number)] = patch
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn1", block_name, number)
|
||||
|
||||
def set_model_attn2_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn2", block_name, number)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
patch_list[i] = patch_list[i].to(device)
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
patch_list[k] = patch_list[k].to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
p = set()
|
||||
for k in patches:
|
||||
if k in self.model_keys:
|
||||
p.add(k)
|
||||
current_patches = self.patches.get(k, [])
|
||||
current_patches.append((strength_patch, patches[k], strength_model))
|
||||
self.patches[k] = current_patches
|
||||
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
model_sd = self.model_state_dict()
|
||||
p = {}
|
||||
for k in model_sd:
|
||||
if filter_prefix is not None:
|
||||
if not k.startswith(filter_prefix):
|
||||
continue
|
||||
if k in self.patches:
|
||||
p[k] = [model_sd[k]] + self.patches[k]
|
||||
else:
|
||||
p[k] = (model_sd[k],)
|
||||
return p
|
||||
|
||||
def model_state_dict(self, filter_prefix=None):
|
||||
sd = self.model.state_dict()
|
||||
keys = list(sd.keys())
|
||||
if filter_prefix is not None:
|
||||
for k in keys:
|
||||
if not k.startswith(filter_prefix):
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_model(self, device_to=None):
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
print("could not patch. key doesn't exist in model:", key)
|
||||
continue
|
||||
|
||||
weight = model_sd[key]
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(self.offload_device)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = weight.float().to(device_to, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
comfy.utils.set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
for p in patches:
|
||||
alpha = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
||||
|
||||
if len(v) == 1:
|
||||
w1 = v[0]
|
||||
if alpha != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
||||
elif len(v) == 4: #lora/locon
|
||||
mat1 = v[0].float().to(weight.device)
|
||||
mat2 = v[1].float().to(weight.device)
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / mat2.shape[0]
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = v[3].float().to(weight.device)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif len(v) == 8: #lokr
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(w1_a.float(), w1_b.float())
|
||||
else:
|
||||
w1 = w1.float().to(weight.device)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
|
||||
else:
|
||||
w2 = w2.float().to(weight.device)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha *= v[2] / dim
|
||||
|
||||
try:
|
||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
else: #loha
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / w1b.shape[0]
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
|
||||
else:
|
||||
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
|
||||
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
|
||||
|
||||
try:
|
||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
|
||||
return weight
|
||||
|
||||
def unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
for k in keys:
|
||||
comfy.utils.set_attr(self.model, k, self.backup[k])
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
286
comfy/sd.py
286
comfy/sd.py
@ -1,7 +1,5 @@
|
||||
import torch
|
||||
import contextlib
|
||||
import copy
|
||||
import inspect
|
||||
import math
|
||||
|
||||
from comfy import model_management
|
||||
@ -21,8 +19,10 @@ from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.supported_models_base
|
||||
|
||||
def load_model_weights(model, sd):
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
@ -53,271 +53,6 @@ def load_clip_weights(model, sd):
|
||||
sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
||||
return load_model_weights(model, sd)
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
if current_device is None:
|
||||
self.current_device = self.offload_device
|
||||
else:
|
||||
self.current_device = current_device
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
model_sd = self.model.state_dict()
|
||||
size = 0
|
||||
for k in model_sd:
|
||||
t = model_sd[k]
|
||||
size += t.nelement() * t.element_size()
|
||||
self.size = size
|
||||
self.model_keys = set(model_sd.keys())
|
||||
return size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
return n
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
to["patches"] = {}
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_patch_replace(self, patch, name, block_name, number):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
to["patches_replace"][name][(block_name, number)] = patch
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn1", block_name, number)
|
||||
|
||||
def set_model_attn2_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn2", block_name, number)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
patch_list[i] = patch_list[i].to(device)
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
patch_list[k] = patch_list[k].to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
p = set()
|
||||
for k in patches:
|
||||
if k in self.model_keys:
|
||||
p.add(k)
|
||||
current_patches = self.patches.get(k, [])
|
||||
current_patches.append((strength_patch, patches[k], strength_model))
|
||||
self.patches[k] = current_patches
|
||||
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
model_sd = self.model_state_dict()
|
||||
p = {}
|
||||
for k in model_sd:
|
||||
if filter_prefix is not None:
|
||||
if not k.startswith(filter_prefix):
|
||||
continue
|
||||
if k in self.patches:
|
||||
p[k] = [model_sd[k]] + self.patches[k]
|
||||
else:
|
||||
p[k] = (model_sd[k],)
|
||||
return p
|
||||
|
||||
def model_state_dict(self, filter_prefix=None):
|
||||
sd = self.model.state_dict()
|
||||
keys = list(sd.keys())
|
||||
if filter_prefix is not None:
|
||||
for k in keys:
|
||||
if not k.startswith(filter_prefix):
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_model(self, device_to=None):
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
print("could not patch. key doesn't exist in model:", k)
|
||||
continue
|
||||
|
||||
weight = model_sd[key]
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(self.offload_device)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = weight.float().to(device_to, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
comfy.utils.set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
for p in patches:
|
||||
alpha = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
||||
|
||||
if len(v) == 1:
|
||||
w1 = v[0]
|
||||
if alpha != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
||||
elif len(v) == 4: #lora/locon
|
||||
mat1 = v[0].float().to(weight.device)
|
||||
mat2 = v[1].float().to(weight.device)
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / mat2.shape[0]
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = v[3].float().to(weight.device)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif len(v) == 8: #lokr
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(w1_a.float(), w1_b.float())
|
||||
else:
|
||||
w1 = w1.float().to(weight.device)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
|
||||
else:
|
||||
w2 = w2.float().to(weight.device)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha *= v[2] / dim
|
||||
|
||||
try:
|
||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
else: #loha
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / w1b.shape[0]
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
|
||||
else:
|
||||
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
|
||||
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
|
||||
|
||||
try:
|
||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
|
||||
return weight
|
||||
|
||||
def unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
for k in keys:
|
||||
comfy.utils.set_attr(self.model, k, self.backup[k])
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model)
|
||||
@ -346,7 +81,7 @@ class CLIP:
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = load_device
|
||||
params['device'] = offload_device
|
||||
if model_management.should_use_fp16(load_device, prioritize_performance=False):
|
||||
params['dtype'] = torch.float16
|
||||
else:
|
||||
@ -355,7 +90,7 @@ class CLIP:
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.layer_idx = None
|
||||
|
||||
def clone(self):
|
||||
@ -573,7 +308,7 @@ def load_gligen(ckpt_path):
|
||||
model = gligen.load_gligen(data)
|
||||
if model_management.should_use_fp16():
|
||||
model = model.half()
|
||||
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||
#TODO: this function is a mess and should be removed eventually
|
||||
@ -614,10 +349,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
model_config = EmptyClass()
|
||||
model_config.unet_config = unet_config
|
||||
model_config = comfy.supported_models_base.BASE({})
|
||||
|
||||
from . import latent_formats
|
||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||
model_config.unet_config = unet_config
|
||||
|
||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
||||
model = model_base.SDInpaint(model_config, model_type=model_type)
|
||||
@ -653,7 +389,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_clip_weights(w, state_dict)
|
||||
|
||||
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
@ -705,7 +441,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if len(left_over) > 0:
|
||||
print("left over keys:", left_over)
|
||||
|
||||
model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
@ -735,7 +471,7 @@ def load_unet(unet_path): #load unet in diffusers format
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
model_management.load_models_gpu([model, clip.load_model()])
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from . import model_base
|
||||
from . import utils
|
||||
from . import latent_formats
|
||||
|
||||
|
||||
def state_dict_key_replace(state_dict, keys_to_replace):
|
||||
@ -33,6 +34,8 @@ class BASE:
|
||||
clip_prefix = []
|
||||
clip_vision_prefix = None
|
||||
noise_aug_config = None
|
||||
beta_schedule = "linear"
|
||||
latent_format = latent_formats.LatentFormat
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
|
||||
@ -3,7 +3,7 @@ import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
def get_canny_nms_kernel(device=None, dtype=None):
|
||||
"""Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
|
||||
@ -290,8 +290,8 @@ class Canny:
|
||||
CATEGORY = "image/preprocessors"
|
||||
|
||||
def detect_edge(self, image, low_threshold, high_threshold):
|
||||
output = canny(image.movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
return (img_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@ -125,6 +125,27 @@ class ImageToMask:
|
||||
mask = image[0, :, :, channels.index(channel)]
|
||||
return (mask,)
|
||||
|
||||
class ImageColorToMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, image, color):
|
||||
temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int)
|
||||
temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2]
|
||||
mask = torch.where(temp == color, 255, 0).float()
|
||||
return (mask,)
|
||||
|
||||
class SolidMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@ -315,6 +336,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ImageCompositeMasked": ImageCompositeMasked,
|
||||
"MaskToImage": MaskToImage,
|
||||
"ImageToMask": ImageToMask,
|
||||
"ImageColorToMask": ImageColorToMask,
|
||||
"SolidMask": SolidMask,
|
||||
"InvertMask": InvertMask,
|
||||
"CropMask": CropMask,
|
||||
|
||||
18
nodes.py
18
nodes.py
@ -244,14 +244,16 @@ class VAEDecode:
|
||||
class VAEDecodeTiled:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64})
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def decode(self, vae, samples):
|
||||
return (vae.decode_tiled(samples["samples"]), )
|
||||
def decode(self, vae, samples, tile_size):
|
||||
return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
|
||||
|
||||
class VAEEncode:
|
||||
@classmethod
|
||||
@ -280,15 +282,17 @@ class VAEEncode:
|
||||
class VAEEncodeTiled:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
def encode(self, vae, pixels, tile_size):
|
||||
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3])
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
@ -471,7 +475,7 @@ class DiffusersLoader:
|
||||
model_path = path
|
||||
break
|
||||
|
||||
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
|
||||
class unCLIPCheckpointLoader:
|
||||
|
||||
28
server.py
28
server.py
@ -1,6 +1,8 @@
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
@ -10,6 +12,7 @@ import json
|
||||
import glob
|
||||
import struct
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from io import BytesIO
|
||||
|
||||
try:
|
||||
@ -79,7 +82,7 @@ class PromptServer():
|
||||
if args.enable_cors_header:
|
||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||
|
||||
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
|
||||
self.app = web.Application(client_max_size=104857600, middlewares=middlewares)
|
||||
self.sockets = dict()
|
||||
self.web_root = os.path.join(os.path.dirname(
|
||||
os.path.realpath(__file__)), "web")
|
||||
@ -88,6 +91,8 @@ class PromptServer():
|
||||
self.last_node_id = None
|
||||
self.client_id = None
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
@routes.get('/ws')
|
||||
async def websocket_handler(request):
|
||||
ws = web.WebSocketResponse()
|
||||
@ -122,7 +127,7 @@ class PromptServer():
|
||||
@routes.get("/embeddings")
|
||||
def get_embeddings(self):
|
||||
embeddings = folder_paths.get_filename_list("embeddings")
|
||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings)))
|
||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||
|
||||
@routes.get("/extensions")
|
||||
async def get_extensions(request):
|
||||
@ -229,13 +234,17 @@ class PromptServer():
|
||||
|
||||
if os.path.isfile(file):
|
||||
with Image.open(file) as original_pil:
|
||||
metadata = PngInfo()
|
||||
if hasattr(original_pil,'text'):
|
||||
for key in original_pil.text:
|
||||
metadata.add_text(key, original_pil.text[key])
|
||||
original_pil = original_pil.convert('RGBA')
|
||||
mask_pil = Image.open(image.file).convert('RGBA')
|
||||
|
||||
# alpha copy
|
||||
new_alpha = mask_pil.getchannel('A')
|
||||
original_pil.putalpha(new_alpha)
|
||||
original_pil.save(filepath, compress_level=4)
|
||||
original_pil.save(filepath, compress_level=4, pnginfo=metadata)
|
||||
|
||||
return image_upload(post, image_save_function)
|
||||
|
||||
@ -438,6 +447,7 @@ class PromptServer():
|
||||
resp_code = 200
|
||||
out_string = ""
|
||||
json_data = await request.json()
|
||||
json_data = self.trigger_on_prompt(json_data)
|
||||
|
||||
if "number" in json_data:
|
||||
number = float(json_data['number'])
|
||||
@ -606,3 +616,15 @@ class PromptServer():
|
||||
if call_on_start is not None:
|
||||
call_on_start(address, port)
|
||||
|
||||
def add_on_prompt_handler(self, handler):
|
||||
self.on_prompt_handlers.append(handler)
|
||||
|
||||
def trigger_on_prompt(self, json_data):
|
||||
for handler in self.on_prompt_handlers:
|
||||
try:
|
||||
json_data = handler(json_data)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] An error occurred during the on_prompt_handler processing")
|
||||
traceback.print_exc()
|
||||
|
||||
return json_data
|
||||
|
||||
167
web/extensions/core/groupOptions.js
Normal file
167
web/extensions/core/groupOptions.js
Normal file
@ -0,0 +1,167 @@
|
||||
import {app} from "../../scripts/app.js";
|
||||
|
||||
function setNodeMode(node, mode) {
|
||||
node.mode = mode;
|
||||
node.graph.change();
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.GroupOptions",
|
||||
setup() {
|
||||
const orig = LGraphCanvas.prototype.getCanvasMenuOptions;
|
||||
// graph_mouse
|
||||
LGraphCanvas.prototype.getCanvasMenuOptions = function () {
|
||||
const options = orig.apply(this, arguments);
|
||||
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
|
||||
if (!group) {
|
||||
return options;
|
||||
}
|
||||
|
||||
// Group nodes aren't recomputed until the group is moved, this ensures the nodes are up-to-date
|
||||
group.recomputeInsideNodes();
|
||||
const nodesInGroup = group._nodes;
|
||||
|
||||
// No nodes in group, return default options
|
||||
if (nodesInGroup.length === 0) {
|
||||
return options;
|
||||
} else {
|
||||
// Add a separator between the default options and the group options
|
||||
options.push(null);
|
||||
}
|
||||
|
||||
// Check if all nodes are the same mode
|
||||
let allNodesAreSameMode = true;
|
||||
for (let i = 1; i < nodesInGroup.length; i++) {
|
||||
if (nodesInGroup[i].mode !== nodesInGroup[0].mode) {
|
||||
allNodesAreSameMode = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Modes
|
||||
// 0: Always
|
||||
// 1: On Event
|
||||
// 2: Never
|
||||
// 3: On Trigger
|
||||
// 4: Bypass
|
||||
// If all nodes are the same mode, add a menu option to change the mode
|
||||
if (allNodesAreSameMode) {
|
||||
const mode = nodesInGroup[0].mode;
|
||||
switch (mode) {
|
||||
case 0:
|
||||
// All nodes are always, option to disable, and bypass
|
||||
options.push({
|
||||
content: "Set Group Nodes to Never",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 2);
|
||||
}
|
||||
}
|
||||
});
|
||||
options.push({
|
||||
content: "Bypass Group Nodes",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 4);
|
||||
}
|
||||
}
|
||||
});
|
||||
break;
|
||||
case 2:
|
||||
// All nodes are never, option to enable, and bypass
|
||||
options.push({
|
||||
content: "Set Group Nodes to Always",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 0);
|
||||
}
|
||||
}
|
||||
});
|
||||
options.push({
|
||||
content: "Bypass Group Nodes",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 4);
|
||||
}
|
||||
}
|
||||
});
|
||||
break;
|
||||
case 4:
|
||||
// All nodes are bypass, option to enable, and disable
|
||||
options.push({
|
||||
content: "Set Group Nodes to Always",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 0);
|
||||
}
|
||||
}
|
||||
});
|
||||
options.push({
|
||||
content: "Set Group Nodes to Never",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 2);
|
||||
}
|
||||
}
|
||||
});
|
||||
break;
|
||||
default:
|
||||
// All nodes are On Trigger or On Event(Or other?), option to disable, set to always, or bypass
|
||||
options.push({
|
||||
content: "Set Group Nodes to Always",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 0);
|
||||
}
|
||||
}
|
||||
});
|
||||
options.push({
|
||||
content: "Set Group Nodes to Never",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 2);
|
||||
}
|
||||
}
|
||||
});
|
||||
options.push({
|
||||
content: "Bypass Group Nodes",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 4);
|
||||
}
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// Nodes are not all the same mode, add a menu option to change the mode to always, never, or bypass
|
||||
options.push({
|
||||
content: "Set Group Nodes to Always",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 0);
|
||||
}
|
||||
}
|
||||
});
|
||||
options.push({
|
||||
content: "Set Group Nodes to Never",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 2);
|
||||
}
|
||||
}
|
||||
});
|
||||
options.push({
|
||||
content: "Bypass Group Nodes",
|
||||
callback: () => {
|
||||
for (const node of nodesInGroup) {
|
||||
setNodeMode(node, 4);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return options
|
||||
}
|
||||
}
|
||||
});
|
||||
@ -6233,11 +6233,17 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
,posAdd:[!mClikSlot_isOut?-30:30, -alphaPosY*130] //-alphaPosY*30]
|
||||
,posSizeFix:[!mClikSlot_isOut?-1:0, 0] //-alphaPosY*2*/
|
||||
});
|
||||
|
||||
skip_action = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!skip_action && this.allow_dragcanvas) {
|
||||
//console.log("pointerevents: dragging_canvas start from middle button");
|
||||
this.dragging_canvas = true;
|
||||
}
|
||||
|
||||
|
||||
} else if (e.which == 3 || this.pointer_is_double) {
|
||||
|
||||
|
||||
@ -299,11 +299,17 @@ export const ComfyWidgets = {
|
||||
const defaultVal = inputData[1].default || "";
|
||||
const multiline = !!inputData[1].multiline;
|
||||
|
||||
let res;
|
||||
if (multiline) {
|
||||
return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
|
||||
res = addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
|
||||
} else {
|
||||
return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
|
||||
res = { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
|
||||
}
|
||||
|
||||
if(inputData[1].dynamicPrompts != undefined)
|
||||
res.widget.dynamicPrompts = inputData[1].dynamicPrompts;
|
||||
|
||||
return res;
|
||||
},
|
||||
COMBO(node, inputName, inputData) {
|
||||
const type = inputData[0];
|
||||
|
||||
Loading…
Reference in New Issue
Block a user