Merge branch 'comfyanonymous:master' into bugfix/extra_data

commit 7f229b7499
@@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")

 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
-fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
+fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")

 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

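The three VAE precision flags live in a single mutually exclusive argparse group, so at most one of them can be passed per run. A minimal standalone sketch of that behaviour, using a throwaway parser rather than the real comfy.cli_args module:

    import argparse

    parser = argparse.ArgumentParser()
    fpvae_group = parser.add_mutually_exclusive_group()
    fpvae_group.add_argument("--fp16-vae", action="store_true")
    fpvae_group.add_argument("--fp32-vae", action="store_true")
    fpvae_group.add_argument("--bf16-vae", action="store_true")

    args = parser.parse_args(["--bf16-vae"])
    assert args.bf16_vae and not args.fp16_vae and not args.fp32_vae
    # Passing two of the flags together (e.g. --fp16-vae --bf16-vae) exits with an argparse error.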
@@ -2,14 +2,27 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
+import contextlib
+
 import comfy.ops
+import comfy.model_patcher
+import comfy.model_management

 class ClipVisionModel():
     def __init__(self, json_config):
         config = CLIPVisionConfig.from_json_file(json_config)
-        with comfy.ops.use_comfy_ops():
+        self.load_device = comfy.model_management.text_encoder_device()
+        offload_device = comfy.model_management.text_encoder_offload_device()
+        self.dtype = torch.float32
+        if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+            self.dtype = torch.float16
+
+        with comfy.ops.use_comfy_ops(offload_device, self.dtype):
             with modeling_utils.no_init_weights():
                 self.model = CLIPVisionModelWithProjection(config)
+        self.model.to(self.dtype)
+
+        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.processor = CLIPImageProcessor(crop_size=224,
                                             do_center_crop=True,
                                             do_convert_rgb=True,
@@ -27,7 +40,21 @@ class ClipVisionModel():
         img = torch.clip((255. * image), 0, 255).round().int()
         img = list(map(lambda a: a, img))
         inputs = self.processor(images=img, return_tensors="pt")
-        outputs = self.model(**inputs)
+        comfy.model_management.load_model_gpu(self.patcher)
+        pixel_values = inputs['pixel_values'].to(self.load_device)
+
+        if self.dtype != torch.float32:
+            precision_scope = torch.autocast
+        else:
+            precision_scope = lambda a, b: contextlib.nullcontext(a)
+
+        with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
+            outputs = self.model(pixel_values=pixel_values)
+
+        for k in outputs:
+            t = outputs[k]
+            if t is not None:
+                outputs[k] = t.cpu()
         return outputs

 def convert_to_transformers(sd, prefix):
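The encode path above selects a precision scope once: torch.autocast when the vision model was loaded in fp16, otherwise a no-op context. A small self-contained sketch of the same idiom (the dtype value here is a placeholder):

    import contextlib
    import torch

    dtype = torch.float32  # would be torch.float16 when should_use_fp16() says so
    if dtype != torch.float32:
        precision_scope = torch.autocast
    else:
        precision_scope = lambda a, b: contextlib.nullcontext(a)

    with precision_scope("cpu", torch.float32):
        out = torch.ones(4) + 1  # runs unchanged here; in the fp16 case the same call would autocast
    print(out)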
@@ -1,9 +1,10 @@
 import torch
 import math
+import os
 import comfy.utils
-import comfy.sd
 import comfy.model_management
 import comfy.model_detection
+import comfy.model_patcher

 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
@@ -129,7 +130,7 @@ class ControlNet(ControlBase):
     def __init__(self, control_model, global_average_pooling=False, device=None):
         super().__init__(device)
         self.control_model = control_model
-        self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
+        self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
         self.global_average_pooling = global_average_pooling

     def get_control(self, x_noisy, t, cond, batched_number):
@@ -257,12 +258,7 @@ class ControlLora(ControlNet):
         cm = self.control_model.state_dict()

         for k in sd:
-            weight = sd[k]
-            if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-                key_split = k.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
-                op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
-                weight = op._hf_hook.weights_map[key_split[-1]]
-
+            weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
             try:
                 comfy.utils.set_attr(self.control_model, k, weight)
             except:
@@ -391,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
         control_model = control_model.half()

    global_average_pooling = False
-   if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
+   filename = os.path.splitext(ckpt_path)[0]
+   if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
@@ -468,7 +465,7 @@ def load_t2i_adapter(t2i_data):
         if len(down_opts) > 0:
             use_conv = True
         xl = False
-        if cin == 256:
+        if cin == 256 or cin == 768:
             xl = True
         model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
     else:
@@ -1,87 +1,36 @@
 import json
 import os
-import yaml
-
-import folder_paths
-from comfy.sd import load_checkpoint
-import os.path as osp
-import re
-import torch
-from safetensors.torch import load_file, save_file
-from . import diffusers_convert
-
-def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
-    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
-    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
-
-    # magic
-    v2 = diffusers_unet_conf["sample_size"] == 96
-    if 'prediction_type' in diffusers_scheduler_conf:
-        v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
-
-    if v2:
-        if v_pred:
-            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
-        else:
-            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
-    else:
-        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
-
-    with open(config_path, 'r') as stream:
-        config = yaml.safe_load(stream)
-
-    model_config_params = config['model']['params']
-    clip_config = model_config_params['cond_stage_config']
-    scale_factor = model_config_params['scale_factor']
-    vae_config = model_config_params['first_stage_config']
-    vae_config['scale_factor'] = scale_factor
-    model_config_params["unet_config"]["params"]["use_fp16"] = fp16
-
-    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
-    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
-    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
-
-    # Load models from safetensors if it exists, if it doesn't pytorch
-    if osp.exists(unet_path):
-        unet_state_dict = load_file(unet_path, device="cpu")
-    else:
-        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
-        unet_state_dict = torch.load(unet_path, map_location="cpu")
-
-    if osp.exists(vae_path):
-        vae_state_dict = load_file(vae_path, device="cpu")
-    else:
-        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
-        vae_state_dict = torch.load(vae_path, map_location="cpu")
-
-    if osp.exists(text_enc_path):
-        text_enc_dict = load_file(text_enc_path, device="cpu")
-    else:
-        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
-        text_enc_dict = torch.load(text_enc_path, map_location="cpu")
-
-    # Convert the UNet model
-    unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
-    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
-
-    # Convert the VAE model
-    vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
-    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
-
-    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
-    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
-
-    if is_v20_model:
-        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
-        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
-        text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
-        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
-    else:
-        text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
-        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
-
-    # Put together new checkpoint
-    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-
-    return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
+
+import comfy.sd
+
+def first_file(path, filenames):
+    for f in filenames:
+        p = os.path.join(path, f)
+        if os.path.exists(p):
+            return p
+    return None
+
+def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
+    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
+    unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
+    vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
+
+    text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
+    text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
+    text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
+
+    text_encoder_paths = [text_encoder1_path]
+    if text_encoder2_path is not None:
+        text_encoder_paths.append(text_encoder2_path)
+
+    unet = comfy.sd.load_unet(unet_path)
+
+    clip = None
+    if output_clip:
+        clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
+
+    vae = None
+    if output_vae:
+        vae = comfy.sd.VAE(ckpt_path=vae_path)
+
+    return (unet, clip, vae)
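The rewritten loader probes for the usual diffusers weight filenames instead of parsing the YAML configs. A standalone sketch of the first_file() helper it relies on (the temporary directory and filenames below are only illustrative):

    import os, tempfile

    def first_file(path, filenames):
        for f in filenames:
            p = os.path.join(path, f)
            if os.path.exists(p):
                return p
        return None

    with tempfile.TemporaryDirectory() as d:
        open(os.path.join(d, "diffusion_pytorch_model.safetensors"), "w").close()
        names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors"]
        print(first_file(d, names))  # picks the plain .safetensors file since the fp16 variant is absent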
@@ -56,7 +56,18 @@ class Upsample(nn.Module):
                                         padding=1)

     def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        try:
+            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        except: #operation not implemented for bf16
+            b, c, h, w = x.shape
+            out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
+            split = 8
+            l = out.shape[1] // split
+            for i in range(0, out.shape[1], l):
+                out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
+            del x
+            x = out
+
         if self.with_conv:
             x = self.conv(x)
         return x
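The fallback branch upsamples in fp32 one slice of channels at a time and casts back, for dtypes where nearest-neighbour interpolation is not implemented. A self-contained sketch of that idea outside the module (shapes chosen arbitrarily):

    import torch

    x = torch.randn(1, 16, 8, 8, dtype=torch.bfloat16)
    out = torch.empty((1, 16, 16, 16), dtype=x.dtype)
    l = out.shape[1] // 8
    for i in range(0, out.shape[1], l):
        out[:, i:i+l] = torch.nn.functional.interpolate(
            x[:, i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest"
        ).to(x.dtype)
    print(out.shape)  # torch.Size([1, 16, 16, 16])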
@@ -74,9 +85,8 @@ class Downsample(nn.Module):
                                         stride=2,
                                         padding=0)

-    def forward(self, x, already_padded=False):
+    def forward(self, x):
         if self.with_conv:
-            if not already_padded:
-                pad = (0,1,0,1)
-                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
             x = self.conv(x)
@@ -275,25 +285,17 @@ class MemoryEfficientAttnBlock(nn.Module):

         # compute attention
         B, C, H, W = q.shape
-        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
             (q, k, v),
         )
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
-        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        try:
+            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+            out = out.transpose(1, 2).reshape(B, C, H, W)
+        except NotImplementedError as e:
+            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
         out = self.proj_out(out)
         return x+out

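The new view/transpose pair builds the same (B, H*W, C) layout that the removed rearrange plus unsqueeze/reshape chain produced. A quick standalone check with toy sizes:

    import torch

    B, C, H, W = 2, 4, 3, 3
    t = torch.randn(B, C, H, W)
    new = t.view(B, C, -1).transpose(1, 2).contiguous()   # (B, H*W, C)
    ref = t.permute(0, 2, 3, 1).reshape(B, H * W, C)      # what rearrange('b c h w -> b (h w) c') gave
    print(torch.equal(new, ref))  # True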
@@ -603,9 +605,6 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
-        pad = (0,1,0,1)
-        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
-        already_padded = True
         # downsampling
         h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
@@ -614,8 +613,7 @@ class Encoder(nn.Module):
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
             if i_level != self.num_resolutions-1:
-                h = self.down[i_level].downsample(h, already_padded)
-                already_padded = False
+                h = self.down[i_level].downsample(h)

         # middle
         h = self.mid.block_1(h, temb)
@@ -118,6 +118,19 @@ def load_lora(lora, to_load):
         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
             patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)

+
+        w_norm_name = "{}.w_norm".format(x)
+        b_norm_name = "{}.b_norm".format(x)
+        w_norm = lora.get(w_norm_name, None)
+        b_norm = lora.get(b_norm_name, None)
+
+        if w_norm is not None:
+            loaded_keys.add(w_norm_name)
+            patch_dict[to_load[x]] = (w_norm,)
+            if b_norm is not None:
+                loaded_keys.add(b_norm_name)
+                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
+
     for x in lora.keys():
         if x not in loaded_keys:
             print("lora key not loaded", x)
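The new w_norm/b_norm entries are registered as single-tensor patches, which the len(v) == 1 branch of ModelPatcher.calculate_weight() later adds straight onto the weight. A toy sketch of that application step (all tensors are placeholders):

    import torch

    weight = torch.zeros(4)
    patch = (torch.ones(4),)   # e.g. the (w_norm,) tuple registered above
    alpha = 1.0
    if len(patch) == 1 and patch[0].shape == weight.shape:
        weight += alpha * patch[0].type(weight.dtype)
    print(weight)  # tensor([1., 1., 1., 1.])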
@@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
+import comfy.model_management
 import numpy as np
 from enum import Enum
 from . import utils
@@ -18,7 +19,8 @@ class BaseModel(torch.nn.Module):
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
-        self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config, device=device)
+        self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+        if not unet_config.get("disable_unet_model_creation", False):
+            self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
@@ -93,7 +95,11 @@ class BaseModel(torch.nn.Module):

     def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
         clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_state_dict = self.diffusion_model.state_dict()
+        unet_sd = self.diffusion_model.state_dict()
+        unet_state_dict = {}
+        for k in unet_sd:
+            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
         if self.get_dtype() == torch.float16:
@@ -1,6 +1,7 @@
 import psutil
 from enum import Enum
 from comfy.cli_args import args
+import comfy.utils
 import torch
 import sys

@@ -147,16 +148,28 @@ def is_nvidia():
         return True

 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+VAE_DTYPE = torch.float32

-if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-    try:
-        if is_nvidia():
-            torch_version = torch.version.__version__
-            if int(torch_version[0]) >= 2:
-                ENABLE_PYTORCH_ATTENTION = True
-    except:
-        pass
+
+try:
+    if is_nvidia():
+        torch_version = torch.version.__version__
+        if int(torch_version[0]) >= 2:
+            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+                ENABLE_PYTORCH_ATTENTION = True
+            if torch.cuda.is_bf16_supported():
+                VAE_DTYPE = torch.bfloat16
+except:
+    pass
+
+if args.fp16_vae:
+    VAE_DTYPE = torch.float16
+elif args.bf16_vae:
+    VAE_DTYPE = torch.bfloat16
+elif args.fp32_vae:
+    VAE_DTYPE = torch.float32
+

 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
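VAE_DTYPE now starts at fp32, is promoted to bf16 on NVIDIA GPUs running torch >= 2 that report bf16 support, and can be overridden by the new CLI flags. A condensed standalone sketch of that decision (the args namespace is a stand-in for comfy.cli_args.args):

    import torch
    from types import SimpleNamespace

    def pick_vae_dtype(args, is_nvidia):
        vae_dtype = torch.float32
        try:
            if is_nvidia and int(torch.version.__version__[0]) >= 2 and torch.cuda.is_bf16_supported():
                vae_dtype = torch.bfloat16
        except:
            pass
        if args.fp16_vae:
            vae_dtype = torch.float16
        elif args.bf16_vae:
            vae_dtype = torch.bfloat16
        elif args.fp32_vae:
            vae_dtype = torch.float32
        return vae_dtype

    print(pick_vae_dtype(SimpleNamespace(fp16_vae=False, bf16_vae=True, fp32_vae=False), is_nvidia=False))  # torch.bfloat16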
@@ -227,6 +240,7 @@ try:
 except:
     print("Could not pick default device.")

+print("VAE dtype:", VAE_DTYPE)

 current_loaded_models = []

@@ -447,12 +461,8 @@ def vae_offload_device():
         return torch.device("cpu")

 def vae_dtype():
-    if args.fp16_vae:
-        return torch.float16
-    elif args.bf16_vae:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    global VAE_DTYPE
+    return VAE_DTYPE

 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
@@ -637,6 +647,13 @@ def soft_empty_cache():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()

+def resolve_lowvram_weight(weight, model, key):
+    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
+        key_split = key.split('.')             # I have no idea why they don't just leave the weight there instead of using the meta device.
+        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
+        weight = op._hf_hook.weights_map[key_split[-1]]
+    return weight
+
 #TODO: might be cleaner to put this somewhere else
 import threading

comfy/model_patcher.py (new file, 270 lines)
@@ -0,0 +1,270 @@
+import torch
+import copy
+import inspect
+
+import comfy.utils
+
+class ModelPatcher:
+    def __init__(self, model, load_device, offload_device, size=0, current_device=None):
+        self.size = size
+        self.model = model
+        self.patches = {}
+        self.backup = {}
+        self.model_options = {"transformer_options":{}}
+        self.model_size()
+        self.load_device = load_device
+        self.offload_device = offload_device
+        if current_device is None:
+            self.current_device = self.offload_device
+        else:
+            self.current_device = current_device
+
+    def model_size(self):
+        if self.size > 0:
+            return self.size
+        model_sd = self.model.state_dict()
+        size = 0
+        for k in model_sd:
+            t = model_sd[k]
+            size += t.nelement() * t.element_size()
+        self.size = size
+        self.model_keys = set(model_sd.keys())
+        return size
+
+    def clone(self):
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
+        n.patches = {}
+        for k in self.patches:
+            n.patches[k] = self.patches[k][:]
+
+        n.model_options = copy.deepcopy(self.model_options)
+        n.model_keys = self.model_keys
+        return n
+
+    def is_clone(self, other):
+        if hasattr(other, 'model') and self.model is other.model:
+            return True
+        return False
+
+    def set_model_sampler_cfg_function(self, sampler_cfg_function):
+        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
+            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
+        else:
+            self.model_options["sampler_cfg_function"] = sampler_cfg_function
+
+    def set_model_unet_function_wrapper(self, unet_wrapper_function):
+        self.model_options["model_function_wrapper"] = unet_wrapper_function
+
+    def set_model_patch(self, patch, name):
+        to = self.model_options["transformer_options"]
+        if "patches" not in to:
+            to["patches"] = {}
+        to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+    def set_model_patch_replace(self, patch, name, block_name, number):
+        to = self.model_options["transformer_options"]
+        if "patches_replace" not in to:
+            to["patches_replace"] = {}
+        if name not in to["patches_replace"]:
+            to["patches_replace"][name] = {}
+        to["patches_replace"][name][(block_name, number)] = patch
+
+    def set_model_attn1_patch(self, patch):
+        self.set_model_patch(patch, "attn1_patch")
+
+    def set_model_attn2_patch(self, patch):
+        self.set_model_patch(patch, "attn2_patch")
+
+    def set_model_attn1_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn1", block_name, number)
+
+    def set_model_attn2_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn2", block_name, number)
+
+    def set_model_attn1_output_patch(self, patch):
+        self.set_model_patch(patch, "attn1_output_patch")
+
+    def set_model_attn2_output_patch(self, patch):
+        self.set_model_patch(patch, "attn2_output_patch")
+
+    def model_patches_to(self, device):
+        to = self.model_options["transformer_options"]
+        if "patches" in to:
+            patches = to["patches"]
+            for name in patches:
+                patch_list = patches[name]
+                for i in range(len(patch_list)):
+                    if hasattr(patch_list[i], "to"):
+                        patch_list[i] = patch_list[i].to(device)
+        if "patches_replace" in to:
+            patches = to["patches_replace"]
+            for name in patches:
+                patch_list = patches[name]
+                for k in patch_list:
+                    if hasattr(patch_list[k], "to"):
+                        patch_list[k] = patch_list[k].to(device)
+
+    def model_dtype(self):
+        if hasattr(self.model, "get_dtype"):
+            return self.model.get_dtype()
+
+    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
+        p = set()
+        for k in patches:
+            if k in self.model_keys:
+                p.add(k)
+                current_patches = self.patches.get(k, [])
+                current_patches.append((strength_patch, patches[k], strength_model))
+                self.patches[k] = current_patches
+
+        return list(p)
+
+    def get_key_patches(self, filter_prefix=None):
+        model_sd = self.model_state_dict()
+        p = {}
+        for k in model_sd:
+            if filter_prefix is not None:
+                if not k.startswith(filter_prefix):
+                    continue
+            if k in self.patches:
+                p[k] = [model_sd[k]] + self.patches[k]
+            else:
+                p[k] = (model_sd[k],)
+        return p
+
+    def model_state_dict(self, filter_prefix=None):
+        sd = self.model.state_dict()
+        keys = list(sd.keys())
+        if filter_prefix is not None:
+            for k in keys:
+                if not k.startswith(filter_prefix):
+                    sd.pop(k)
+        return sd
+
+    def patch_model(self, device_to=None):
+        model_sd = self.model_state_dict()
+        for key in self.patches:
+            if key not in model_sd:
+                print("could not patch. key doesn't exist in model:", key)
+                continue
+
+            weight = model_sd[key]
+
+            if key not in self.backup:
+                self.backup[key] = weight.to(self.offload_device)
+
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
+            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            comfy.utils.set_attr(self.model, key, out_weight)
+            del temp_weight
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
+
+        return self.model
+
+    def calculate_weight(self, patches, weight, key):
+        for p in patches:
+            alpha = p[0]
+            v = p[1]
+            strength_model = p[2]
+
+            if strength_model != 1.0:
+                weight *= strength_model
+
+            if isinstance(v, list):
+                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
+
+            if len(v) == 1:
+                w1 = v[0]
+                if alpha != 0.0:
+                    if w1.shape != weight.shape:
+                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                    else:
+                        weight += alpha * w1.type(weight.dtype).to(weight.device)
+            elif len(v) == 4: #lora/locon
+                mat1 = v[0].float().to(weight.device)
+                mat2 = v[1].float().to(weight.device)
+                if v[2] is not None:
+                    alpha *= v[2] / mat2.shape[0]
+                if v[3] is not None:
+                    #locon mid weights, hopefully the math is fine because I didn't properly test it
+                    mat3 = v[3].float().to(weight.device)
+                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+                try:
+                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+            elif len(v) == 8: #lokr
+                w1 = v[0]
+                w2 = v[1]
+                w1_a = v[3]
+                w1_b = v[4]
+                w2_a = v[5]
+                w2_b = v[6]
+                t2 = v[7]
+                dim = None
+
+                if w1 is None:
+                    dim = w1_b.shape[0]
+                    w1 = torch.mm(w1_a.float(), w1_b.float())
+                else:
+                    w1 = w1.float().to(weight.device)
+
+                if w2 is None:
+                    dim = w2_b.shape[0]
+                    if t2 is None:
+                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
+                    else:
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
+                else:
+                    w2 = w2.float().to(weight.device)
+
+                if len(w2.shape) == 4:
+                    w1 = w1.unsqueeze(2).unsqueeze(2)
+                if v[2] is not None and dim is not None:
+                    alpha *= v[2] / dim
+
+                try:
+                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+            else: #loha
+                w1a = v[0]
+                w1b = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / w1b.shape[0]
+                w2a = v[3]
+                w2b = v[4]
+                if v[5] is not None: #cp decomposition
+                    t1 = v[5]
+                    t2 = v[6]
+                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
+                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
+                else:
+                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
+                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
+
+                try:
+                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+
+        return weight
+
+    def unpatch_model(self, device_to=None):
+        keys = list(self.backup.keys())
+
+        for k in keys:
+            comfy.utils.set_attr(self.model, k, self.backup[k])
+
+        self.backup = {}
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
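A hedged usage sketch of the relocated ModelPatcher on a toy module. It assumes the snippet runs from a ComfyUI checkout so comfy.model_patcher is importable; the patch values are made up:

    import torch
    import comfy.model_patcher

    model = torch.nn.Linear(4, 4)
    patcher = comfy.model_patcher.ModelPatcher(model, load_device=torch.device("cpu"),
                                               offload_device=torch.device("cpu"))

    delta = {"weight": (torch.ones_like(model.weight),)}  # single-tensor patch, applied as weight += alpha * w
    patcher.add_patches(delta, strength_patch=0.5)

    patcher.patch_model()    # applies patches in fp32 and keeps a backup of the original weights
    patcher.unpatch_model()  # restores the backed-up weights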
comfy/sd.py (286 changed lines)
@@ -1,7 +1,5 @@
 import torch
 import contextlib
-import copy
-import inspect
 import math

 from comfy import model_management
@@ -21,8 +19,10 @@ from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip

+import comfy.model_patcher
 import comfy.lora
 import comfy.t2i_adapter.adapter
+import comfy.supported_models_base

 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
@@ -53,271 +53,6 @@ def load_clip_weights(model, sd):
         sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
     return load_model_weights(model, sd)

-class ModelPatcher:
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None):
-        self.size = size
-        self.model = model
-        self.patches = {}
-        self.backup = {}
-        self.model_options = {"transformer_options":{}}
-        self.model_size()
-        self.load_device = load_device
-        self.offload_device = offload_device
-        if current_device is None:
-            self.current_device = self.offload_device
-        else:
-            self.current_device = current_device
-
-    def model_size(self):
-        if self.size > 0:
-            return self.size
-        model_sd = self.model.state_dict()
-        size = 0
-        for k in model_sd:
-            t = model_sd[k]
-            size += t.nelement() * t.element_size()
-        self.size = size
-        self.model_keys = set(model_sd.keys())
-        return size
-
-    def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
-        n.patches = {}
-        for k in self.patches:
-            n.patches[k] = self.patches[k][:]
-
-        n.model_options = copy.deepcopy(self.model_options)
-        n.model_keys = self.model_keys
-        return n
-
-    def is_clone(self, other):
-        if hasattr(other, 'model') and self.model is other.model:
-            return True
-        return False
-
-    def set_model_sampler_cfg_function(self, sampler_cfg_function):
-        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
-            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
-        else:
-            self.model_options["sampler_cfg_function"] = sampler_cfg_function
-
-    def set_model_unet_function_wrapper(self, unet_wrapper_function):
-        self.model_options["model_function_wrapper"] = unet_wrapper_function
-
-    def set_model_patch(self, patch, name):
-        to = self.model_options["transformer_options"]
-        if "patches" not in to:
-            to["patches"] = {}
-        to["patches"][name] = to["patches"].get(name, []) + [patch]
-
-    def set_model_patch_replace(self, patch, name, block_name, number):
-        to = self.model_options["transformer_options"]
-        if "patches_replace" not in to:
-            to["patches_replace"] = {}
-        if name not in to["patches_replace"]:
-            to["patches_replace"][name] = {}
-        to["patches_replace"][name][(block_name, number)] = patch
-
-    def set_model_attn1_patch(self, patch):
-        self.set_model_patch(patch, "attn1_patch")
-
-    def set_model_attn2_patch(self, patch):
-        self.set_model_patch(patch, "attn2_patch")
-
-    def set_model_attn1_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn1", block_name, number)
-
-    def set_model_attn2_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn2", block_name, number)
-
-    def set_model_attn1_output_patch(self, patch):
-        self.set_model_patch(patch, "attn1_output_patch")
-
-    def set_model_attn2_output_patch(self, patch):
-        self.set_model_patch(patch, "attn2_output_patch")
-
-    def model_patches_to(self, device):
-        to = self.model_options["transformer_options"]
-        if "patches" in to:
-            patches = to["patches"]
-            for name in patches:
-                patch_list = patches[name]
-                for i in range(len(patch_list)):
-                    if hasattr(patch_list[i], "to"):
-                        patch_list[i] = patch_list[i].to(device)
-        if "patches_replace" in to:
-            patches = to["patches_replace"]
-            for name in patches:
-                patch_list = patches[name]
-                for k in patch_list:
-                    if hasattr(patch_list[k], "to"):
-                        patch_list[k] = patch_list[k].to(device)
-
-    def model_dtype(self):
-        if hasattr(self.model, "get_dtype"):
-            return self.model.get_dtype()
-
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
-        for k in patches:
-            if k in self.model_keys:
-                p.add(k)
-                current_patches = self.patches.get(k, [])
-                current_patches.append((strength_patch, patches[k], strength_model))
-                self.patches[k] = current_patches
-
-        return list(p)
-
-    def get_key_patches(self, filter_prefix=None):
-        model_sd = self.model_state_dict()
-        p = {}
-        for k in model_sd:
-            if filter_prefix is not None:
-                if not k.startswith(filter_prefix):
-                    continue
-            if k in self.patches:
-                p[k] = [model_sd[k]] + self.patches[k]
-            else:
-                p[k] = (model_sd[k],)
-        return p
-
-    def model_state_dict(self, filter_prefix=None):
-        sd = self.model.state_dict()
-        keys = list(sd.keys())
-        if filter_prefix is not None:
-            for k in keys:
-                if not k.startswith(filter_prefix):
-                    sd.pop(k)
-        return sd
-
-    def patch_model(self, device_to=None):
-        model_sd = self.model_state_dict()
-        for key in self.patches:
-            if key not in model_sd:
-                print("could not patch. key doesn't exist in model:", k)
-                continue
-
-            weight = model_sd[key]
-
-            if key not in self.backup:
-                self.backup[key] = weight.to(self.offload_device)
-
-            if device_to is not None:
-                temp_weight = weight.float().to(device_to, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
-            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-            comfy.utils.set_attr(self.model, key, out_weight)
-            del temp_weight
-
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
-
-        return self.model
-
-    def calculate_weight(self, patches, weight, key):
-        for p in patches:
-            alpha = p[0]
-            v = p[1]
-            strength_model = p[2]
-
-            if strength_model != 1.0:
-                weight *= strength_model
-
-            if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
-
-            if len(v) == 1:
-                w1 = v[0]
-                if alpha != 0.0:
-                    if w1.shape != weight.shape:
-                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
-                    else:
-                        weight += alpha * w1.type(weight.dtype).to(weight.device)
-            elif len(v) == 4: #lora/locon
-                mat1 = v[0].float().to(weight.device)
-                mat2 = v[1].float().to(weight.device)
-                if v[2] is not None:
-                    alpha *= v[2] / mat2.shape[0]
-                if v[3] is not None:
-                    #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    mat3 = v[3].float().to(weight.device)
-                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                try:
-                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-            elif len(v) == 8: #lokr
-                w1 = v[0]
-                w2 = v[1]
-                w1_a = v[3]
-                w1_b = v[4]
-                w2_a = v[5]
-                w2_b = v[6]
-                t2 = v[7]
-                dim = None
-
-                if w1 is None:
-                    dim = w1_b.shape[0]
-                    w1 = torch.mm(w1_a.float(), w1_b.float())
-                else:
-                    w1 = w1.float().to(weight.device)
-
-                if w2 is None:
-                    dim = w2_b.shape[0]
-                    if t2 is None:
-                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
-                    else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
-                else:
-                    w2 = w2.float().to(weight.device)
-
-                if len(w2.shape) == 4:
-                    w1 = w1.unsqueeze(2).unsqueeze(2)
-                if v[2] is not None and dim is not None:
-                    alpha *= v[2] / dim
-
-                try:
-                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-            else: #loha
-                w1a = v[0]
-                w1b = v[1]
-                if v[2] is not None:
-                    alpha *= v[2] / w1b.shape[0]
-                w2a = v[3]
-                w2b = v[4]
-                if v[5] is not None: #cp decomposition
-                    t1 = v[5]
-                    t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
-                else:
-                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
-                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
-
-                try:
-                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-
-        return weight
-
-    def unpatch_model(self, device_to=None):
-        keys = list(self.backup.keys())
-
-        for k in keys:
-            comfy.utils.set_attr(self.model, k, self.backup[k])
-
-        self.backup = {}
-
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
-

 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     key_map = comfy.lora.model_lora_keys_unet(model.model)
@@ -346,7 +81,7 @@ class CLIP:

         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = load_device
+        params['device'] = offload_device
         if model_management.should_use_fp16(load_device, prioritize_performance=False):
             params['dtype'] = torch.float16
         else:
@@ -355,7 +90,7 @@ class CLIP:
         self.cond_stage_model = clip(**(params))

         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
-        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None

     def clone(self):
@@ -573,7 +308,7 @@ def load_gligen(ckpt_path):
     model = gligen.load_gligen(data)
     if model_management.should_use_fp16():
         model = model.half()
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
     #TODO: this function is a mess and should be removed eventually
@@ -614,10 +349,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     class EmptyClass:
         pass

-    model_config = EmptyClass()
-    model_config.unet_config = unet_config
+    model_config = comfy.supported_models_base.BASE({})
+
     from . import latent_formats
     model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
+    model_config.unet_config = unet_config

     if config['model']["target"].endswith("LatentInpaintDiffusion"):
         model = model_base.SDInpaint(model_config, model_type=model_type)
@@ -653,7 +389,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         w.cond_stage_model = clip.cond_stage_model
         load_clip_weights(w, state_dict)

-    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
+    return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)

 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
     sd = comfy.utils.load_torch_file(ckpt_path)
@@ -705,7 +441,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)

-    model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
     if inital_load_device != torch.device("cpu"):
         print("loaded straight to GPU")
         model_management.load_model_gpu(model_patcher)
@@ -735,7 +471,7 @@ def load_unet(unet_path): #load unet in diffusers format
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)

 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     model_management.load_models_gpu([model, clip.load_model()])
@@ -1,6 +1,7 @@
 import torch
 from . import model_base
 from . import utils
+from . import latent_formats


 def state_dict_key_replace(state_dict, keys_to_replace):
@@ -33,6 +34,8 @@ class BASE:
     clip_prefix = []
     clip_vision_prefix = None
     noise_aug_config = None
+    beta_schedule = "linear"
+    latent_format = latent_formats.LatentFormat

     @classmethod
     def matches(s, unet_config):
@@ -3,7 +3,7 @@ import math

 import torch
 import torch.nn.functional as F
-
+import comfy.model_management

 def get_canny_nms_kernel(device=None, dtype=None):
     """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
@@ -290,8 +290,8 @@ class Canny:
     CATEGORY = "image/preprocessors"

     def detect_edge(self, image, low_threshold, high_threshold):
-        output = canny(image.movedim(-1, 1), low_threshold, high_threshold)
-        img_out = output[1].repeat(1, 3, 1, 1).movedim(1, -1)
+        output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
+        img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
         return (img_out,)

 NODE_CLASS_MAPPINGS = {
@@ -125,6 +125,27 @@ class ImageToMask:
         mask = image[0, :, :, channels.index(channel)]
         return (mask,)

+class ImageColorToMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+                "required": {
+                    "image": ("IMAGE",),
+                    "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
+                }
+        }
+
+    CATEGORY = "mask"
+
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "image_to_mask"
+
+    def image_to_mask(self, image, color):
+        temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int)
+        temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2]
+        mask = torch.where(temp == color, 255, 0).float()
+        return (mask,)
+
 class SolidMask:
     @classmethod
     def INPUT_TYPES(cls):
@ -315,6 +336,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ImageCompositeMasked": ImageCompositeMasked,
|
"ImageCompositeMasked": ImageCompositeMasked,
|
||||||
"MaskToImage": MaskToImage,
|
"MaskToImage": MaskToImage,
|
||||||
"ImageToMask": ImageToMask,
|
"ImageToMask": ImageToMask,
|
||||||
|
"ImageColorToMask": ImageColorToMask,
|
||||||
"SolidMask": SolidMask,
|
"SolidMask": SolidMask,
|
||||||
"InvertMask": InvertMask,
|
"InvertMask": InvertMask,
|
||||||
"CropMask": CropMask,
|
"CropMask": CropMask,
|
||||||
|
|||||||
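The new ImageColorToMask node matches pixels by packing the 8-bit R, G, B channels into a single 0xRRGGBB integer and comparing it with the color input, which is why the widget's range tops out at 0xFFFFFF. A small self-contained check of that packing, using only torch; the image contents are illustrative:

    import torch

    # One red pixel and three black pixels, channels-last in [0, 1].
    image = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                          [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    # Quantize to 8 bits per channel and pack into 0xRRGGBB, as the node does.
    temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
    packed = (torch.bitwise_left_shift(temp[:, :, 0], 16)
              + torch.bitwise_left_shift(temp[:, :, 1], 8)
              + temp[:, :, 2])

    mask = torch.where(packed == 0xFF0000, 255, 0).float()
    print(mask)   # tensor([[255., 0.], [0., 0.]])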
nodes.py  (18 lines changed)
@@ -244,14 +244,16 @@ class VAEDecode:
 class VAEDecodeTiled:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+        return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
+                             "tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64})
+                            }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "decode"

     CATEGORY = "_for_testing"

-    def decode(self, vae, samples):
+    def decode(self, vae, samples, tile_size):
-        return (vae.decode_tiled(samples["samples"]), )
+        return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )

 class VAEEncode:
     @classmethod
@@ -280,15 +282,17 @@ class VAEEncode:
 class VAEEncodeTiled:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
+        return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
+                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
+                            }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"

     CATEGORY = "_for_testing"

-    def encode(self, vae, pixels):
+    def encode(self, vae, pixels, tile_size):
         pixels = VAEEncode.vae_encode_crop_pixels(pixels)
-        t = vae.encode_tiled(pixels[:,:,:,:3])
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
         return ({"samples":t}, )

 class VAEEncodeForInpaint:
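Both tiled nodes now take tile_size in pixels, but they hand it to the VAE differently: decoding works in latent space, so the value is divided by the VAE's 8x downscale factor, while encoding tiles the image directly in pixel space. A quick sanity check of that relationship, plain Python with no ComfyUI imports:

    tile_size = 512                 # default pixel tile from the widgets above

    latent_tile = tile_size // 8    # what VAEDecodeTiled passes as tile_x/tile_y
    pixel_tile = tile_size          # what VAEEncodeTiled passes as tile_x/tile_y

    assert latent_tile == 64
    print(f"decode: {latent_tile}x{latent_tile} latent tiles, "
          f"encode: {pixel_tile}x{pixel_tile} pixel tiles")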
@@ -471,7 +475,7 @@ class DiffusersLoader:
                 model_path = path
                 break

-        return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))


 class unCLIPCheckpointLoader:
server.py  (28 lines changed)
@@ -1,6 +1,8 @@
 import os
 import sys
 import asyncio
+import traceback

 import nodes
 import folder_paths
 import execution
@@ -10,6 +12,7 @@ import json
 import glob
 import struct
 from PIL import Image, ImageOps
+from PIL.PngImagePlugin import PngInfo
 from io import BytesIO

 try:
@@ -79,7 +82,7 @@ class PromptServer():
         if args.enable_cors_header:
             middlewares.append(create_cors_middleware(args.enable_cors_header))

-        self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
+        self.app = web.Application(client_max_size=104857600, middlewares=middlewares)
         self.sockets = dict()
         self.web_root = os.path.join(os.path.dirname(
             os.path.realpath(__file__)), "web")
@@ -88,6 +91,8 @@ class PromptServer():
         self.last_node_id = None
         self.client_id = None

+        self.on_prompt_handlers = []

         @routes.get('/ws')
         async def websocket_handler(request):
             ws = web.WebSocketResponse()
@@ -122,7 +127,7 @@ class PromptServer():
        @routes.get("/embeddings")
         def get_embeddings(self):
             embeddings = folder_paths.get_filename_list("embeddings")
-            return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings)))
+            return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

         @routes.get("/extensions")
         async def get_extensions(request):
@@ -229,13 +234,17 @@ class PromptServer():

             if os.path.isfile(file):
                 with Image.open(file) as original_pil:
+                    metadata = PngInfo()
+                    if hasattr(original_pil,'text'):
+                        for key in original_pil.text:
+                            metadata.add_text(key, original_pil.text[key])
                     original_pil = original_pil.convert('RGBA')
                     mask_pil = Image.open(image.file).convert('RGBA')

                     # alpha copy
                     new_alpha = mask_pil.getchannel('A')
                     original_pil.putalpha(new_alpha)
-                    original_pil.save(filepath, compress_level=4)
+                    original_pil.save(filepath, compress_level=4, pnginfo=metadata)

             return image_upload(post, image_save_function)

@@ -438,6 +447,7 @@ class PromptServer():
         resp_code = 200
         out_string = ""
         json_data = await request.json()
+        json_data = self.trigger_on_prompt(json_data)

         if "number" in json_data:
             number = float(json_data['number'])
@@ -606,3 +616,15 @@ class PromptServer():
         if call_on_start is not None:
             call_on_start(address, port)

+    def add_on_prompt_handler(self, handler):
+        self.on_prompt_handlers.append(handler)
+
+    def trigger_on_prompt(self, json_data):
+        for handler in self.on_prompt_handlers:
+            try:
+                json_data = handler(json_data)
+            except Exception as e:
+                print(f"[ERROR] An error occurred during the on_prompt_handler processing")
+                traceback.print_exc()
+
+        return json_data
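The new add_on_prompt_handler/trigger_on_prompt pair lets extensions inspect or rewrite the prompt JSON before it is queued; an exception inside a handler is logged and the payload continues through unchanged. A minimal sketch of registering a handler from a custom node, assuming the usual ComfyUI convention that the running server is exposed as PromptServer.instance; the handler body and the key it writes are illustrative:

    import server  # ComfyUI's server module, importable from a custom node package

    def stamp_prompt(json_data):
        # Handlers receive the parsed /prompt payload and must return it.
        json_data.setdefault("extra_data", {})["stamped_by"] = "my_extension"
        return json_data

    # Assumes PromptServer stores its singleton on the class, as custom nodes commonly rely on.
    server.PromptServer.instance.add_on_prompt_handler(stamp_prompt)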
web/extensions/core/groupOptions.js  (167 lines, new file)
@@ -0,0 +1,167 @@
+import {app} from "../../scripts/app.js";
+
+function setNodeMode(node, mode) {
+    node.mode = mode;
+    node.graph.change();
+}
+
+app.registerExtension({
+    name: "Comfy.GroupOptions",
+    setup() {
+        const orig = LGraphCanvas.prototype.getCanvasMenuOptions;
+        // graph_mouse
+        LGraphCanvas.prototype.getCanvasMenuOptions = function () {
+            const options = orig.apply(this, arguments);
+            const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
+            if (!group) {
+                return options;
+            }
+
+            // Group nodes aren't recomputed until the group is moved, this ensures the nodes are up-to-date
+            group.recomputeInsideNodes();
+            const nodesInGroup = group._nodes;
+
+            // No nodes in group, return default options
+            if (nodesInGroup.length === 0) {
+                return options;
+            } else {
+                // Add a separator between the default options and the group options
+                options.push(null);
+            }
+
+            // Check if all nodes are the same mode
+            let allNodesAreSameMode = true;
+            for (let i = 1; i < nodesInGroup.length; i++) {
+                if (nodesInGroup[i].mode !== nodesInGroup[0].mode) {
+                    allNodesAreSameMode = false;
+                    break;
+                }
+            }
+
+            // Modes
+            // 0: Always
+            // 1: On Event
+            // 2: Never
+            // 3: On Trigger
+            // 4: Bypass
+            // If all nodes are the same mode, add a menu option to change the mode
+            if (allNodesAreSameMode) {
+                const mode = nodesInGroup[0].mode;
+                switch (mode) {
+                    case 0:
+                        // All nodes are always, option to disable, and bypass
+                        options.push({
+                            content: "Set Group Nodes to Never",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 2);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Bypass Group Nodes",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 4);
+                                }
+                            }
+                        });
+                        break;
+                    case 2:
+                        // All nodes are never, option to enable, and bypass
+                        options.push({
+                            content: "Set Group Nodes to Always",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 0);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Bypass Group Nodes",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 4);
+                                }
+                            }
+                        });
+                        break;
+                    case 4:
+                        // All nodes are bypass, option to enable, and disable
+                        options.push({
+                            content: "Set Group Nodes to Always",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 0);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Set Group Nodes to Never",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 2);
+                                }
+                            }
+                        });
+                        break;
+                    default:
+                        // All nodes are On Trigger or On Event(Or other?), option to disable, set to always, or bypass
+                        options.push({
+                            content: "Set Group Nodes to Always",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 0);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Set Group Nodes to Never",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 2);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Bypass Group Nodes",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 4);
+                                }
+                            }
+                        });
+                        break;
+                }
+            } else {
+                // Nodes are not all the same mode, add a menu option to change the mode to always, never, or bypass
+                options.push({
+                    content: "Set Group Nodes to Always",
+                    callback: () => {
+                        for (const node of nodesInGroup) {
+                            setNodeMode(node, 0);
+                        }
+                    }
+                });
+                options.push({
+                    content: "Set Group Nodes to Never",
+                    callback: () => {
+                        for (const node of nodesInGroup) {
+                            setNodeMode(node, 2);
+                        }
+                    }
+                });
+                options.push({
+                    content: "Bypass Group Nodes",
+                    callback: () => {
+                        for (const node of nodesInGroup) {
+                            setNodeMode(node, 4);
+                        }
+                    }
+                });
+            }
+
+            return options
+        }
+    }
+});
@@ -6233,11 +6233,17 @@ LGraphNode.prototype.executeAction = function(action)
                         ,posAdd:[!mClikSlot_isOut?-30:30, -alphaPosY*130] //-alphaPosY*30]
                         ,posSizeFix:[!mClikSlot_isOut?-1:0, 0] //-alphaPosY*2*/
                         });
+                        skip_action = true;
+                    }
+                }
+            }
+
+            if (!skip_action && this.allow_dragcanvas) {
+                //console.log("pointerevents: dragging_canvas start from middle button");
+                this.dragging_canvas = true;
             }
-            }
-        }
-
     } else if (e.which == 3 || this.pointer_is_double) {
@@ -299,11 +299,17 @@ export const ComfyWidgets = {
         const defaultVal = inputData[1].default || "";
         const multiline = !!inputData[1].multiline;

+        let res;
         if (multiline) {
-            return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
+            res = addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
         } else {
-            return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
+            res = { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
         }

+        if(inputData[1].dynamicPrompts != undefined)
+            res.widget.dynamicPrompts = inputData[1].dynamicPrompts;
+
+        return res;
     },
     COMBO(node, inputName, inputData) {
         const type = inputData[0];