merge upstream

doctorpangloss 2023-08-29 13:36:53 -07:00
commit db673f7728
22 changed files with 1418 additions and 1082 deletions


@@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
-fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
+fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")


@@ -2,14 +2,28 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
+import contextlib
 from . import ops
+import comfy.ops
+import comfy.model_patcher
+import comfy.model_management

 class ClipVisionModel():
     def __init__(self, json_config):
         config = CLIPVisionConfig.from_json_file(json_config)
-        with ops.use_comfy_ops():
+        self.load_device = comfy.model_management.text_encoder_device()
+        offload_device = comfy.model_management.text_encoder_offload_device()
+        self.dtype = torch.float32
+        if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+            self.dtype = torch.float16
+        with ops.use_comfy_ops(offload_device, self.dtype):
             with modeling_utils.no_init_weights():
                 self.model = CLIPVisionModelWithProjection(config)
+        self.model.to(self.dtype)
+        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.processor = CLIPImageProcessor(crop_size=224,
                                             do_center_crop=True,
                                             do_convert_rgb=True,
@@ -27,7 +41,21 @@ class ClipVisionModel():
         img = torch.clip((255. * image), 0, 255).round().int()
         img = list(map(lambda a: a, img))
         inputs = self.processor(images=img, return_tensors="pt")
-        outputs = self.model(**inputs)
+        comfy.model_management.load_model_gpu(self.patcher)
+        pixel_values = inputs['pixel_values'].to(self.load_device)
+
+        if self.dtype != torch.float32:
+            precision_scope = torch.autocast
+        else:
+            precision_scope = lambda a, b: contextlib.nullcontext(a)
+
+        with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
+            outputs = self.model(pixel_values=pixel_values)
+
+        for k in outputs:
+            t = outputs[k]
+            if t is not None:
+                outputs[k] = t.cpu()
         return outputs

 def convert_to_transformers(sd, prefix):


@@ -1,5 +1,6 @@
 from __future__ import annotations
 import asyncio
+import traceback
 import glob
 import struct
 import sys
@@ -7,6 +8,7 @@ import shutil
 from urllib.parse import quote
 from PIL import Image, ImageOps
+from PIL.PngImagePlugin import PngInfo
 from io import BytesIO
 import json
@@ -98,7 +100,7 @@ class PromptServer():
         if args.enable_cors_header:
             middlewares.append(create_cors_middleware(args.enable_cors_header))

-        self.app = web.Application(client_max_size=20971520, handler_args={'max_field_size': 16380},
+        self.app = web.Application(client_max_size=104857600, handler_args={'max_field_size': 16380},
                                    middlewares=middlewares)
         self.sockets = dict()
         web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../web")
@@ -111,6 +113,8 @@ class PromptServer():
         self.last_node_id = None
         self.client_id = None

+        self.on_prompt_handlers = []
+
         @routes.get('/ws')
         async def websocket_handler(request):
             ws = web.WebSocketResponse()
@@ -252,13 +256,17 @@ class PromptServer():
                 if os.path.isfile(file):
                     with Image.open(file) as original_pil:
+                        metadata = PngInfo()
+                        if hasattr(original_pil,'text'):
+                            for key in original_pil.text:
+                                metadata.add_text(key, original_pil.text[key])
                         original_pil = original_pil.convert('RGBA')
                         mask_pil = Image.open(image.file).convert('RGBA')

                         # alpha copy
                         new_alpha = mask_pil.getchannel('A')
                         original_pil.putalpha(new_alpha)
-                    original_pil.save(filepath, compress_level=4)
+                    original_pil.save(filepath, compress_level=4, pnginfo=metadata)

             return image_upload(post, image_save_function)
@@ -463,6 +471,7 @@ class PromptServer():
             resp_code = 200
             out_string = ""
             json_data = await request.json()
+            json_data = self.trigger_on_prompt(json_data)

             if "number" in json_data:
                 number = float(json_data['number'])
@@ -761,6 +770,19 @@ class PromptServer():
         if call_on_start is not None:
             call_on_start(address, port)

+    def add_on_prompt_handler(self, handler):
+        self.on_prompt_handlers.append(handler)
+
+    def trigger_on_prompt(self, json_data):
+        for handler in self.on_prompt_handlers:
+            try:
+                json_data = handler(json_data)
+            except Exception as e:
+                print(f"[ERROR] An error occurred during the on_prompt_handler processing")
+                traceback.print_exc()
+
+        return json_data
+
     @classmethod
     def get_output_path(cls, subfolder: str | None = None, filename: str | None = None):
         paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""]
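A minimal usage sketch of the on-prompt hook added above (not part of this commit). The handler name and payload key are made up for illustration; a handler receives the submitted prompt JSON and must return it (possibly modified), and any exception it raises is caught and logged by trigger_on_prompt.

    # hypothetical custom-extension code
    def tag_prompt(json_data):
        json_data.setdefault("extra_data", {})["tagged_by"] = "example_handler"
        return json_data

    server = PromptServer.instance          # assumes the singleton attribute is set elsewhere
    server.add_on_prompt_handler(tag_prompt)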

comfy/controlnet.py (new file, 480 lines)

@@ -0,0 +1,480 @@
import torch
import math
import os
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.model_patcher
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
per_batch = target_batch_size // batched_number
tensor = tensor[:per_batch]
if per_batch > tensor.shape[0]:
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
current_batch_size = tensor.shape[0]
if current_batch_size == target_batch_size:
return tensor
else:
return torch.cat([tensor] * batched_number, dim=0)
class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (1.0, 0.0)
self.timestep_range = None
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
return self
def pre_run(self, model, percent_to_timestep_function):
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
if self.previous_controlnet is not None:
self.previous_controlnet.pre_run(model, percent_to_timestep_function)
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.timestep_range = None
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
return self.previous_controlnet.inference_memory_requirements(dtype)
return 0
def control_merge(self, control_input, control_output, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []}
if control_input is not None:
for i in range(len(control_input)):
key = 'input'
x = control_input[i]
if x is not None:
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].insert(0, x)
if control_output is not None:
for i in range(len(control_output)):
if i == (len(control_output) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control_output[i]
if x is not None:
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]
for i in range(len(control_prev[x])):
prev_val = control_prev[x][i]
if i >= len(o):
o.append(prev_val)
elif prev_val is not None:
if o[i] is None:
o[i] = prev_val
else:
o[i] += prev_val
return out
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return {}
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = torch.cat(cond['c_crossattn'], 1)
y = cond.get('c_adm', None)
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
self.copy_to(c)
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
return out
class ControlLoraOps:
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = None
self.up = None
self.down = None
self.bias = None
def forward(self, input):
if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
class Conv2d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
device=None,
dtype=None
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.transposed = False
self.output_padding = 0
self.groups = groups
self.padding_mode = padding_mode
self.weight = None
self.bias = None
self.up = None
self.down = None
def forward(self, input):
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
ControlBase.__init__(self, device)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
dtype = model.get_dtype()
self.control_model.to(dtype)
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
try:
comfy.utils.set_attr(self.control_model, k, weight)
except:
pass
for k in self.control_weights:
if k not in {"lora_controlnet"}:
comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
self.copy_to(c)
return c
def cleanup(self):
del self.control_model
self.control_model = None
super().cleanup()
def get_models(self):
out = ControlBase.get_models(self)
return out
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = comfy.model_management.should_use_fp16()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
k_in = "controlnet_down_blocks.{}{}".format(count, s)
k_out = "zero_convs.{}.0{}".format(count, s)
if k_in not in controlnet_data:
loop = False
break
diffusers_keys[k_in] = k_out
count += 1
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
if count == 0:
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
else:
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
k_out = "input_hint_block.{}{}".format(count * 2, s)
if k_in not in controlnet_data:
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
loop = False
diffusers_keys[k_in] = k_out
count += 1
new_sd = {}
for k in diffusers_keys:
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
print("leftover keys:", leftover_keys)
controlnet_data = new_sd
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
key = 'zero_convs.0.0.weight'
if pth_key in controlnet_data:
pth = True
key = pth_key
prefix = "control_model."
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
return net
if controlnet_config is None:
use_fp16 = comfy.model_management.should_use_fp16()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
if pth:
if 'difference' in controlnet_data:
if model is not None:
comfy.model_management.load_models_gpu([model])
model_sd = model.model_state_dict()
for x in controlnet_data:
c_m = "control_model."
if x.startswith(c_m):
sd_key = "diffusion_model.{}".format(x[len(c_m):])
if sd_key in model_sd:
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.control_model = control_model
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else:
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
if use_fp16:
control_model = control_model.half()
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, device=None):
super().__init__(device)
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
width = math.ceil(width / unshuffle_amount) * unshuffle_amount
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
return width, height
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return {}
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.control_input = None
self.cond_hint = None
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
self.t2i_model.to(x_noisy.dtype)
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
mid = None
if self.t2i_model.xl == True:
mid = control_input[-1:]
control_input = control_input[:-1]
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in)
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
if 'adapter' in keys:
t2i_data = t2i_data['adapter']
keys = t2i_data.keys()
if "body.0.in_conv.weight" in keys:
cin = t2i_data['body.0.in_conv.weight'].shape[1]
model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
elif 'conv_in.weight' in keys:
cin = t2i_data['conv_in.weight'].shape[1]
channel = t2i_data['conv_in.weight'].shape[0]
ksize = t2i_data['body.0.block2.weight'].shape[2]
use_conv = False
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
if len(down_opts) > 0:
use_conv = True
xl = False
if cin == 256:
xl = True
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
else:
return None
missing, unexpected = model_ad.load_state_dict(t2i_data)
if len(missing) > 0:
print("t2i missing", missing)
if len(unexpected) > 0:
print("t2i unexpected", unexpected)
return T2IAdapter(model_ad, model_ad.input_channels)
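A rough usage sketch for the new controlnet module above (not part of the diff). The checkpoint path and hint tensor are placeholders; in practice the hint comes from an image node as an NCHW tensor in the 0..1 range.

    import torch
    import comfy.controlnet

    hint = torch.zeros((1, 3, 512, 512))  # placeholder image hint, NCHW in 0..1
    control = comfy.controlnet.load_controlnet("models/controlnet/canny.safetensors")  # placeholder path
    control.set_cond_hint(hint, strength=0.8, timestep_percent_range=(1.0, 0.0))
    # get_models() returns the wrapped control model(s) so the sampler can ask
    # model_management to budget memory for them before sampling.
    models = control.get_models()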


@@ -56,7 +56,18 @@ class Upsample(nn.Module):
                                        padding=1)

     def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        try:
+            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        except: #operation not implemented for bf16
+            b, c, h, w = x.shape
+            out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
+            split = 8
+            l = out.shape[1] // split
+            for i in range(0, out.shape[1], l):
+                out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
+            del x
+            x = out
         if self.with_conv:
             x = self.conv(x)
         return x
@@ -74,11 +85,10 @@ class Downsample(nn.Module):
                                        stride=2,
                                        padding=0)

-    def forward(self, x, already_padded=False):
+    def forward(self, x):
         if self.with_conv:
-            if not already_padded:
-                pad = (0,1,0,1)
-                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
             x = self.conv(x)
         else:
             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@@ -275,25 +285,17 @@ class MemoryEfficientAttnBlock(nn.Module):
         # compute attention
         B, C, H, W = q.shape
-        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
             (q, k, v),
         )
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
-        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        try:
+            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+            out = out.transpose(1, 2).reshape(B, C, H, W)
+        except NotImplementedError as e:
+            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
         out = self.proj_out(out)
         return x+out
@@ -603,9 +605,6 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
-        pad = (0,1,0,1)
-        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
-        already_padded = True
         # downsampling
         h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
@@ -614,8 +613,7 @@ class Encoder(nn.Module):
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
             if i_level != self.num_resolutions-1:
-                h = self.down[i_level].downsample(h, already_padded)
-                already_padded = False
+                h = self.down[i_level].downsample(h)

         # middle
         h = self.mid.block_1(h, temb)

comfy/lora.py (new file, 199 lines)

@@ -0,0 +1,199 @@
import comfy.utils
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
"mlp.fc2": "mlp_fc2",
"self_attn.k_proj": "self_attn_k_proj",
"self_attn.q_proj": "self_attn_q_proj",
"self_attn.v_proj": "self_attn_v_proj",
"self_attn.out_proj": "self_attn_out_proj",
}
def load_lora(lora, to_load):
patch_dict = {}
loaded_keys = set()
for x in to_load:
alpha_name = "{}.alpha".format(x)
alpha = None
if alpha_name in lora.keys():
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
loaded_keys.add(A_name)
loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
w_norm = lora.get(w_norm_name, None)
b_norm = lora.get(b_norm_name, None)
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,)
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)
return patch_dict
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
for b in range(32):
for c in LORA_CLIP_MAP:
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
else:
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
return key_map
def model_lora_keys_unet(model, key_map={}):
sdk = model.state_dict().keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
if k.endswith(".weight"):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
return key_map
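A sketch of how the lora.py helpers above fit together (not part of the diff). The file name is a placeholder, and `model` / `clip` are assumed to be an already loaded ModelPatcher and CLIP object; the resulting patches are applied through the ModelPatcher added later in this commit.

    import comfy.utils
    import comfy.lora

    lora_sd = comfy.utils.load_torch_file("loras/example.safetensors", safe_load=True)  # placeholder path
    key_map = comfy.lora.model_lora_keys_unet(model.model)                    # lora key -> unet weight key
    key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
    patches = comfy.lora.load_lora(lora_sd, key_map)                          # model weight key -> patch tuple
    model.add_patches(patches, strength_patch=1.0)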


@@ -3,6 +3,7 @@ from .ldm.modules.diffusionmodules.openaimodel import UNetModel
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from .ldm.modules.diffusionmodules.util import make_beta_schedule
 from .ldm.modules.diffusionmodules.openaimodel import Timestep
+import comfy.model_management
 import numpy as np
 from enum import Enum
 from . import utils
@@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module):
     def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
         clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_state_dict = self.diffusion_model.state_dict()
+        unet_sd = self.diffusion_model.state_dict()
+        unet_state_dict = {}
+        for k in unet_sd:
+            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
         if self.get_dtype() == torch.float16:


@@ -1,6 +1,7 @@
 import psutil
 from enum import Enum
 from .cli_args import args
+import comfy.utils
 import torch
 import sys
@@ -111,9 +112,6 @@ if not args.normalvram and not args.cpu:
     if lowvram_available and total_vram <= 4096:
         print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
         set_vram_to = VRAMState.LOW_VRAM
-    elif total_vram > total_ram * 1.1 and total_vram > 14336:
-        print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-        vram_state = VRAMState.HIGH_VRAM

 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
@@ -150,15 +148,27 @@ def is_nvidia():
         return True

 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+VAE_DTYPE = torch.float32

-if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
 try:
     if is_nvidia():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
+            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-except:
-    pass
+            if torch.cuda.is_bf16_supported():
+                VAE_DTYPE = torch.bfloat16
+except:
+    pass
+
+if args.fp16_vae:
+    VAE_DTYPE = torch.float16
+elif args.bf16_vae:
+    VAE_DTYPE = torch.bfloat16
+elif args.fp32_vae:
+    VAE_DTYPE = torch.float32

 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
@@ -230,6 +240,7 @@ try:
 except:
     print("Could not pick default device.")

+print("VAE dtype:", VAE_DTYPE)

 current_loaded_models = []
@@ -302,16 +313,15 @@ def unload_model_clones(model):
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
     for i in range(len(current_loaded_models) -1, -1, -1):
-        if DISABLE_SMART_MEMORY:
-            current_free_mem = 0
-        else:
-            current_free_mem = get_free_memory(device)
-        if current_free_mem > memory_required:
-            break
+        if not DISABLE_SMART_MEMORY:
+            if get_free_memory(device) > memory_required:
+                break
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded:
-                current_loaded_models.pop(i).model_unload()
+                m = current_loaded_models.pop(i)
+                m.model_unload()
+                del m
                 unloaded_model = True

     if unloaded_model:
@@ -394,6 +404,12 @@ def cleanup_models():
             x.model_unload()
             del x

+def dtype_size(dtype):
+    dtype_size = 4
+    if dtype == torch.float16 or dtype == torch.bfloat16:
+        dtype_size = 2
+    return dtype_size
+
 def unet_offload_device():
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
@@ -409,11 +425,7 @@ def unet_inital_load_device(parameters, dtype):
     if DISABLE_SMART_MEMORY:
         return cpu_dev

-    dtype_size = 4
-    if dtype == torch.float16 or dtype == torch.bfloat16:
-        dtype_size = 2
-
-    model_size = dtype_size * parameters
+    model_size = dtype_size(dtype) * parameters
     mem_dev = get_free_memory(torch_dev)
     mem_cpu = get_free_memory(cpu_dev)
@@ -432,8 +444,7 @@ def text_encoder_device():
     if args.gpu_only:
         return get_torch_device()
     elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
-        #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
-        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
+        if should_use_fp16(prioritize_performance=False):
             return get_torch_device()
         else:
             return torch.device("cpu")
@@ -450,12 +461,8 @@ def vae_offload_device():
         return torch.device("cpu")

 def vae_dtype():
-    if args.fp16_vae:
-        return torch.float16
-    elif args.bf16_vae:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    global VAE_DTYPE
+    return VAE_DTYPE

 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
@@ -569,15 +576,19 @@ def is_device_mps(device):
             return True
     return False

-def should_use_fp16(device=None, model_params=0):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     global xpu_available
     global directml_enabled

+    if device is not None:
+        if is_device_cpu(device):
+            return False
+
     if FORCE_FP16:
         return True

     if device is not None: #TODO
-        if is_device_cpu(device) or is_device_mps(device):
+        if is_device_mps(device):
             return False

     if FORCE_FP32:
@@ -610,7 +621,7 @@ def should_use_fp16(device=None, model_params=0):
     if fp16_works:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
-        if model_params * 4 > free_model_memory:
+        if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True

     if props.major < 7:
@@ -636,6 +647,13 @@ def soft_empty_cache():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()

+def resolve_lowvram_weight(weight, model, key):
+    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
+        key_split = key.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
+        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
+        weight = op._hf_hook.weights_map[key_split[-1]]
+    return weight
+
 #TODO: might be cleaner to put this somewhere else
 import threading

comfy/model_patcher.py (new file, 270 lines)

@@ -0,0 +1,270 @@
import torch
import copy
import inspect
import comfy.utils
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", k)
continue
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
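A minimal sketch of the ModelPatcher lifecycle above (not part of the diff), using a stand-in module and CPU devices so it is self-contained:

    import torch
    import comfy.model_patcher

    unet = torch.nn.Linear(8, 8)  # stand-in for a real diffusion model
    patcher = comfy.model_patcher.ModelPatcher(unet, load_device=torch.device("cpu"),
                                               offload_device=torch.device("cpu"))

    clone = patcher.clone()                       # shares the model, copies the patch list
    diff = {"weight": (torch.zeros(8, 8),)}       # a length-1 tuple is a plain additive diff
    clone.add_patches(diff, strength_patch=0.5)
    clone.patch_model()                           # backs up originals, applies calculate_weight
    clone.unpatch_model()                         # restores the backed-up weights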


@@ -22,7 +22,7 @@ from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 from ..nodes.common import MAX_RESOLUTION
+import comfy.controlnet

 class CLIPTextEncode:
     @classmethod
@@ -226,14 +226,16 @@ class VAEDecode:
 class VAEDecodeTiled:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+        return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
+                             "tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64})
+                            }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "decode"

     CATEGORY = "_for_testing"

-    def decode(self, vae, samples):
-        return (vae.decode_tiled(samples["samples"]), )
+    def decode(self, vae, samples, tile_size):
+        return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )

 class VAEEncode:
     @classmethod
@@ -262,15 +264,17 @@ class VAEEncode:
 class VAEEncodeTiled:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
+        return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
+                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
+                            }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"

     CATEGORY = "_for_testing"

-    def encode(self, vae, pixels):
+    def encode(self, vae, pixels, tile_size):
         pixels = VAEEncode.vae_encode_crop_pixels(pixels)
-        t = vae.encode_tiled(pixels[:,:,:,:3])
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
         return ({"samples":t}, )

 class VAEEncodeForInpaint:
@@ -552,7 +556,7 @@ class ControlNetLoader:
     def load_controlnet(self, control_net_name):
         controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
-        controlnet = sd.load_controlnet(controlnet_path)
+        controlnet = comfy.controlnet.load_controlnet(controlnet_path)
         return (controlnet,)

 class DiffControlNetLoader:
@@ -568,7 +572,7 @@ class DiffControlNetLoader:
     def load_controlnet(self, model, control_net_name):
         controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
-        controlnet = sd.load_controlnet(controlnet_path, model)
+        controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
         return (controlnet,)

@@ -1292,7 +1296,7 @@ class LoadImage:
         input_dir = folder_paths.get_input_directory()
         files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
         return {"required":
-                    {"image": (sorted(files), )},
+                    {"image": (sorted(files), {"image_upload": True})},
                }

     CATEGORY = "image"
@@ -1335,7 +1339,7 @@ class LoadImageMask:
         input_dir = folder_paths.get_input_directory()
         files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
         return {"required":
-                    {"image": (sorted(files), ),
+                    {"image": (sorted(files), {"image_upload": True}),
                      "channel": (s._color_channels, ), }
                }


@@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs):
         raise ValueError(f"unsupported dimensions: {dims}")

 @contextmanager
-def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
+def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
-    torch.nn.Linear = Linear
+    force_device = device
+    force_dtype = dtype
+    def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
+        if force_device is not None:
+            device = force_device
+        if force_dtype is not None:
+            dtype = force_dtype
+        return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
+    torch.nn.Linear = linear_with_dtype
     try:
         yield
     finally:
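A small sketch of what the updated context manager does (not part of the diff; import path and the Linear interface are assumptions): while the block is active, any torch.nn.Linear constructed inside it is swapped for comfy's Linear op, pinned to the forced device and dtype.

    import torch
    from comfy import ops  # assumed import path for this package

    with ops.use_comfy_ops(device=torch.device("cpu"), dtype=torch.float16):
        layer = torch.nn.Linear(16, 16)   # actually constructs ops.Linear on cpu in fp16
    # assuming ops.Linear keeps the torch.nn.Linear interface:
    print(type(layer), layer.weight.dtype)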


@@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models

-def get_additional_models(positive, negative):
+def get_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))

+    inference_memory = 0
     control_models = []
     for m in control_nets:
         control_models += m.get_models()
+        inference_memory += m.inference_memory_requirements(dtype)

     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
     gligen = [x[1] for x in gligen]
     models = control_models + gligen
-    return models
+    return models, inference_memory

 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
@@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
         noise_mask = prepare_mask(noise_mask, noise.shape, device)

     real_model = None
-    models = get_additional_models(positive, negative)
-    model_management.load_models_gpu([model] + models, model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
+    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
+    model_management.load_models_gpu([model] + models, model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
     real_model = model.model

     noise = noise.to(device)

File diff suppressed because it is too large.


@@ -44,7 +44,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None):  # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
         self.num_layers = 12
@@ -57,17 +57,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
             config = CLIPTextConfig.from_json_file(textmodel_json_config)
             self.num_layers = config.num_hidden_layers
-            with ops.use_comfy_ops():
+            with ops.use_comfy_ops(device, dtype):
                 with modeling_utils.no_init_weights():
                     self.transformer = CLIPTextModel(config)

+        if dtype is not None:
+            self.transformer.to(dtype)
+
         self.max_length = max_length
         if freeze:
             self.freeze()
         self.layer = layer
         self.layer_idx = None
         self.empty_tokens = [[49406] + [49407] * 76]
-        self.text_projection = None
+        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = True
         if layer == "hidden":
             assert layer_idx is not None
@@ -140,9 +144,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if backup_embeds.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
-            precision_scope = contextlib.nullcontext
+            precision_scope = lambda a, b: contextlib.nullcontext(a)

-        with precision_scope(model_management.get_autocast_device(device)):
+        with precision_scope(model_management.get_autocast_device(device), torch.float32):
             outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
             self.transformer.set_input_embeddings(backup_embeds)
@@ -157,13 +161,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             pooled_output = outputs.pooler_output

         if self.text_projection is not None:
-            pooled_output = pooled_output.to(self.text_projection.device) @ self.text_projection
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output.float()

     def encode(self, tokens):
         return self(tokens)

     def load_sd(self, sd):
+        if "text_projection" in sd:
+            self.text_projection[:] = sd.pop("text_projection")
+        if "text_projection.weight" in sd:
+            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
         return self.transformer.load_state_dict(sd, strict=False)

 def parse_parentheses(string):
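text_projection is now always a square parameter initialized to the identity rather than None, so checkpoints without a projection behave exactly as before while SDXL-style checkpoints can overwrite it via load_sd; a minimal sketch of why the identity is a safe default (widths assumed: 768 for clip-vit-large, 1280 for the SDXL "G" encoder):

    import torch

    hidden = 768  # clip-vit-large text width; the SDXL "G" encoder uses 1280
    text_projection = torch.nn.Parameter(torch.eye(hidden))
    pooled = torch.randn(2, hidden)
    # Multiplying by the identity is a no-op, matching the old
    # self.text_projection = None behaviour when no projection is loaded.
    assert torch.allclose(pooled @ text_projection, pooled)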

View File

@@ -1,11 +1,10 @@
 from pkg_resources import resource_filename
 from . import sd1_clip
-import torch
 import os

 class SD2ClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=23
@@ -13,7 +12,7 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
         if not os.path.exists(textmodel_json_config):
             textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]

     def clip_layer(self, layer_idx):

View File

@@ -17,7 +17,7 @@
   "num_attention_heads": 16,
   "num_hidden_layers": 24,
   "pad_token_id": 1,
-  "projection_dim": 512,
+  "projection_dim": 1024,
   "torch_dtype": "float32",
   "vocab_size": 49408
 }
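The bumped projection_dim lines up with the 1024-wide hidden size of the ViT-H text encoder, so the identity text_projection introduced in sd1_clip keeps a consistent shape; a quick hedged check (the config path below is an assumption, point it at your checkout):

    from transformers import CLIPTextConfig

    config = CLIPTextConfig.from_json_file("comfy/sd2_clip_config.json")  # path assumed
    assert config.projection_dim == config.hidden_size == 1024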

View File

@@ -3,23 +3,17 @@ import torch
 import os

 class SDXLClipG(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
-        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = False

     def load_sd(self, sd):
-        if "text_projection" in sd:
-            self.text_projection[:] = sd.pop("text_projection")
-        if "text_projection.weight" in sd:
-            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
         return super().load_sd(sd)

 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
@@ -42,11 +36,11 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
         return self.clip_g.untokenize(token_weight_pair)

 class SDXLClipModel(torch.nn.Module):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device)
+        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
         self.clip_l.layer_norm_hidden_state = False
-        self.clip_g = SDXLClipG(device=device)
+        self.clip_g = SDXLClipG(device=device, dtype=dtype)

     def clip_layer(self, layer_idx):
         self.clip_l.clip_layer(layer_idx)
@@ -70,9 +64,9 @@ class SDXLClipModel(torch.nn.Module):
         return self.clip_l.load_sd(sd)

 class SDXLRefinerClipModel(torch.nn.Module):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_g = SDXLClipG(device=device)
+        self.clip_g = SDXLClipG(device=device, dtype=dtype)

     def clip_layer(self, layer_idx):
         self.clip_g.clip_layer(layer_idx)
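A hedged usage sketch of the new dtype plumbing, assuming fp16 is acceptable on the target device and the bundled clip config files resolve:

    import torch

    # Both SDXL text encoders forward dtype to SD1ClipModel, which casts the
    # transformer after construction.
    clip = SDXLClipModel(device="cpu", dtype=torch.float16)
    assert next(clip.clip_g.transformer.parameters()).dtype == torch.float16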

View File

@@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):

 class Adapter(nn.Module):
-    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
+    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
         super(Adapter, self).__init__()
-        self.unshuffle = nn.PixelUnshuffle(8)
+        self.unshuffle_amount = 8
+        resblock_no_downsample = []
+        resblock_downsample = [3, 2, 1]
+        self.xl = xl
+        if self.xl:
+            self.unshuffle_amount = 16
+            resblock_no_downsample = [1]
+            resblock_downsample = [2]
+
+        self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
+        self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
         self.channels = channels
         self.nums_rb = nums_rb
         self.body = []
         for i in range(len(channels)):
             for j in range(nums_rb):
-                if (i != 0) and (j == 0):
+                if (i in resblock_downsample) and (j == 0):
                     self.body.append(
                         ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
+                elif (i in resblock_no_downsample) and (j == 0):
+                    self.body.append(
+                        ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
                 else:
                     self.body.append(
                         ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
@@ -128,8 +141,16 @@ class Adapter(nn.Module):
             for j in range(self.nums_rb):
                 idx = i * self.nums_rb + j
                 x = self.body[idx](x)
-            features.append(None)
-            features.append(None)
+            if self.xl:
+                features.append(None)
+                if i == 0:
+                    features.append(None)
+                    features.append(None)
+                if i == 2:
+                    features.append(None)
+            else:
+                features.append(None)
+                features.append(None)
             features.append(x)

         return features
@@ -243,10 +264,14 @@ class extractor(nn.Module):
 class Adapter_light(nn.Module):
     def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
         super(Adapter_light, self).__init__()
-        self.unshuffle = nn.PixelUnshuffle(8)
+        self.unshuffle_amount = 8
+        self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
+        self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
         self.channels = channels
         self.nums_rb = nums_rb
         self.body = []
+        self.xl = False
+
         for i in range(len(channels)):
             if i == 0:
                 self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
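A hedged sketch of what the xl switch means for input preparation, assuming the Adapter class above is importable and using a hypothetical cin value:

    import torch

    # With xl=True the pixel-unshuffle factor doubles from 8 to 16, so the SDXL
    # variant consumes hint images at 1/16 spatial resolution after unshuffling.
    adapter = Adapter(cin=256 * 3, xl=True)   # hypothetical cin for a 3-channel hint
    print(adapter.unshuffle_amount)           # 16
    print(adapter.input_channels)             # 256 * 3 // (16 * 16) = 3
    hint = torch.randn(1, 3, 1024, 1024)
    print(adapter.unshuffle(hint).shape)      # torch.Size([1, 768, 64, 64])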

View File

@@ -34,6 +34,13 @@ def save_torch_file(sd, ckpt, metadata=None):
     else:
         safetensors.torch.save_file(sd, ckpt)

+def calculate_parameters(sd, prefix=""):
+    params = 0
+    for k in sd.keys():
+        if k.startswith(prefix):
+            params += sd[k].nelement()
+    return params
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -232,6 +239,20 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
             return None
         return f.read(length_of_header)

+def set_attr(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    prev = getattr(obj, attrs[-1])
+    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    del prev
+
+def get_attr(obj, attr):
+    attrs = attr.split(".")
+    for name in attrs:
+        obj = getattr(obj, name)
+    return obj
+
 def bislerp(samples, width, height):
     def slerp(b1, b2, r):
         '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
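Hedged usage sketches for the new helpers (the comfy.utils import path is assumed):

    import torch
    from comfy.utils import calculate_parameters, get_attr, set_attr  # path assumed

    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    sd = model.state_dict()

    # Parameter count under a state-dict prefix: 4*4 weights + 4 biases.
    print(calculate_parameters(sd, "0."))     # 20

    # Dotted-path read and in-place parameter replacement.
    w = get_attr(model, "0.weight")
    set_attr(model, "0.weight", torch.zeros_like(w))
    print(get_attr(model, "0.weight").abs().sum().item())  # 0.0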

View File

@@ -0,0 +1,167 @@
import {app} from "../../scripts/app.js";
function setNodeMode(node, mode) {
node.mode = mode;
node.graph.change();
}
app.registerExtension({
name: "Comfy.GroupOptions",
setup() {
const orig = LGraphCanvas.prototype.getCanvasMenuOptions;
// graph_mouse
LGraphCanvas.prototype.getCanvasMenuOptions = function () {
const options = orig.apply(this, arguments);
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
if (!group) {
return options;
}
// Group nodes aren't recomputed until the group is moved, this ensures the nodes are up-to-date
group.recomputeInsideNodes();
const nodesInGroup = group._nodes;
// No nodes in group, return default options
if (nodesInGroup.length === 0) {
return options;
} else {
// Add a separator between the default options and the group options
options.push(null);
}
// Check if all nodes are the same mode
let allNodesAreSameMode = true;
for (let i = 1; i < nodesInGroup.length; i++) {
if (nodesInGroup[i].mode !== nodesInGroup[0].mode) {
allNodesAreSameMode = false;
break;
}
}
// Modes
// 0: Always
// 1: On Event
// 2: Never
// 3: On Trigger
// 4: Bypass
// If all nodes are the same mode, add a menu option to change the mode
if (allNodesAreSameMode) {
const mode = nodesInGroup[0].mode;
switch (mode) {
case 0:
// All nodes are always, option to disable, and bypass
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
case 2:
// All nodes are never, option to enable, and bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
case 4:
// All nodes are bypass, option to enable, and disable
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
break;
default:
// All nodes are On Trigger or On Event(Or other?), option to disable, set to always, or bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
break;
}
} else {
// Nodes are not all the same mode, add a menu option to change the mode to always, never, or bypass
options.push({
content: "Set Group Nodes to Always",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 0);
}
}
});
options.push({
content: "Set Group Nodes to Never",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 2);
}
}
});
options.push({
content: "Bypass Group Nodes",
callback: () => {
for (const node of nodesInGroup) {
setNodeMode(node, 4);
}
}
});
}
return options
}
}
});

View File

@@ -5,7 +5,7 @@ import { app } from "../../scripts/app.js";
 app.registerExtension({
 	name: "Comfy.UploadImage",
 	async beforeRegisterNodeDef(nodeType, nodeData, app) {
-		if (nodeData.name === "LoadImage" || nodeData.name === "LoadImageMask") {
+		if (nodeData?.input?.required?.image?.[1]?.image_upload === true) {
 			nodeData.input.required.upload = ["IMAGEUPLOAD"];
 		}
 	},
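With the hard-coded node-name check gone, any node can opt into the upload widget by flagging its image input; a hedged sketch of a custom node doing so (the folder_paths helpers are assumed to match stock ComfyUI, and the class below is hypothetical):

    import os
    import folder_paths

    class MyLoadImage:
        @classmethod
        def INPUT_TYPES(cls):
            input_dir = folder_paths.get_input_directory()
            files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
            # The frontend now keys on image_upload=True instead of the node name.
            return {"required": {"image": (sorted(files), {"image_upload": True})}}

        RETURN_TYPES = ("IMAGE",)
        FUNCTION = "load_image"
        CATEGORY = "image"

        def load_image(self, image):
            ...  # load the file and return it as a tensor here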

View File

@@ -299,11 +299,17 @@ export const ComfyWidgets = {
 		const defaultVal = inputData[1].default || "";
 		const multiline = !!inputData[1].multiline;

+		let res;
 		if (multiline) {
-			return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
+			res = addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
 		} else {
-			return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
+			res = { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
 		}
+
+		if(inputData[1].dynamicPrompts != undefined)
+			res.widget.dynamicPrompts = inputData[1].dynamicPrompts;
+
+		return res;
 	},
 	COMBO(node, inputName, inputData) {
 		const type = inputData[0];
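The STRING widget now carries a dynamicPrompts flag taken from the node definition; a hedged Python-side sketch of an input that sets it (assuming widget options are forwarded to the frontend verbatim, as with multiline, and that the node class below is hypothetical):

    class MyPromptNode:
        @classmethod
        def INPUT_TYPES(cls):
            # dynamicPrompts ends up on the widget created above (res.widget.dynamicPrompts).
            return {"required": {
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": ""}),
            }}

        RETURN_TYPES = ("STRING",)
        FUNCTION = "run"
        CATEGORY = "utils"

        def run(self, text):
            return (text,)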