diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b4f22f319..fda245433 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") -fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.") +fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") +fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 596a1b718..7ae4088f7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -2,14 +2,28 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm from .utils import load_torch_file, transformers_convert import os import torch +import contextlib from . import ops +import comfy.ops +import comfy.model_patcher +import comfy.model_management + class ClipVisionModel(): def __init__(self, json_config): config = CLIPVisionConfig.from_json_file(json_config) - with ops.use_comfy_ops(): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = torch.float32 + if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): + self.dtype = torch.float16 + + with ops.use_comfy_ops(offload_device, self.dtype): with modeling_utils.no_init_weights(): self.model = CLIPVisionModelWithProjection(config) + self.model.to(self.dtype) + + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.processor = CLIPImageProcessor(crop_size=224, do_center_crop=True, do_convert_rgb=True, @@ -27,7 +41,21 @@ class ClipVisionModel(): img = torch.clip((255. 
* image), 0, 255).round().int() img = list(map(lambda a: a, img)) inputs = self.processor(images=img, return_tensors="pt") - outputs = self.model(**inputs) + comfy.model_management.load_model_gpu(self.patcher) + pixel_values = inputs['pixel_values'].to(self.load_device) + + if self.dtype != torch.float32: + precision_scope = torch.autocast + else: + precision_scope = lambda a, b: contextlib.nullcontext(a) + + with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32): + outputs = self.model(pixel_values=pixel_values) + + for k in outputs: + t = outputs[k] + if t is not None: + outputs[k] = t.cpu() return outputs def convert_to_transformers(sd, prefix): diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 04a3750d8..0760ab9b5 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -1,5 +1,6 @@ from __future__ import annotations import asyncio +import traceback import glob import struct import sys @@ -7,6 +8,7 @@ import shutil from urllib.parse import quote from PIL import Image, ImageOps +from PIL.PngImagePlugin import PngInfo from io import BytesIO import json @@ -98,7 +100,7 @@ class PromptServer(): if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) - self.app = web.Application(client_max_size=20971520, handler_args={'max_field_size': 16380}, + self.app = web.Application(client_max_size=104857600, handler_args={'max_field_size': 16380}, middlewares=middlewares) self.sockets = dict() web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../web") @@ -111,6 +113,8 @@ class PromptServer(): self.last_node_id = None self.client_id = None + self.on_prompt_handlers = [] + @routes.get('/ws') async def websocket_handler(request): ws = web.WebSocketResponse() @@ -252,13 +256,17 @@ class PromptServer(): if os.path.isfile(file): with Image.open(file) as original_pil: + metadata = PngInfo() + if hasattr(original_pil,'text'): + for key in original_pil.text: + metadata.add_text(key, original_pil.text[key]) original_pil = original_pil.convert('RGBA') mask_pil = Image.open(image.file).convert('RGBA') # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath, compress_level=4) + original_pil.save(filepath, compress_level=4, pnginfo=metadata) return image_upload(post, image_save_function) @@ -463,6 +471,7 @@ class PromptServer(): resp_code = 200 out_string = "" json_data = await request.json() + json_data = self.trigger_on_prompt(json_data) if "number" in json_data: number = float(json_data['number']) @@ -761,6 +770,19 @@ class PromptServer(): if call_on_start is not None: call_on_start(address, port) + def add_on_prompt_handler(self, handler): + self.on_prompt_handlers.append(handler) + + def trigger_on_prompt(self, json_data): + for handler in self.on_prompt_handlers: + try: + json_data = handler(json_data) + except Exception as e: + print(f"[ERROR] An error occurred during the on_prompt_handler processing") + traceback.print_exc() + + return json_data + @classmethod def get_output_path(cls, subfolder: str | None = None, filename: str | None = None): paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""] diff --git a/comfy/controlnet.py b/comfy/controlnet.py new file mode 100644 index 000000000..7098186f9 --- /dev/null +++ b/comfy/controlnet.py @@ -0,0 +1,480 @@ +import torch +import math +import os +import comfy.utils +import comfy.model_management +import comfy.model_detection +import 
comfy.model_patcher + +import comfy.cldm.cldm +import comfy.t2i_adapter.adapter + + +def broadcast_image_to(tensor, target_batch_size, batched_number): + current_batch_size = tensor.shape[0] + #print(current_batch_size, target_batch_size) + if current_batch_size == 1: + return tensor + + per_batch = target_batch_size // batched_number + tensor = tensor[:per_batch] + + if per_batch > tensor.shape[0]: + tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0) + + current_batch_size = tensor.shape[0] + if current_batch_size == target_batch_size: + return tensor + else: + return torch.cat([tensor] * batched_number, dim=0) + +class ControlBase: + def __init__(self, device=None): + self.cond_hint_original = None + self.cond_hint = None + self.strength = 1.0 + self.timestep_percent_range = (1.0, 0.0) + self.timestep_range = None + + if device is None: + device = comfy.model_management.get_torch_device() + self.device = device + self.previous_controlnet = None + self.global_average_pooling = False + + def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): + self.cond_hint_original = cond_hint + self.strength = strength + self.timestep_percent_range = timestep_percent_range + return self + + def pre_run(self, model, percent_to_timestep_function): + self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) + if self.previous_controlnet is not None: + self.previous_controlnet.pre_run(model, percent_to_timestep_function) + + def set_previous_controlnet(self, controlnet): + self.previous_controlnet = controlnet + return self + + def cleanup(self): + if self.previous_controlnet is not None: + self.previous_controlnet.cleanup() + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.timestep_range = None + + def get_models(self): + out = [] + if self.previous_controlnet is not None: + out += self.previous_controlnet.get_models() + return out + + def copy_to(self, c): + c.cond_hint_original = self.cond_hint_original + c.strength = self.strength + c.timestep_percent_range = self.timestep_percent_range + + def inference_memory_requirements(self, dtype): + if self.previous_controlnet is not None: + return self.previous_controlnet.inference_memory_requirements(dtype) + return 0 + + def control_merge(self, control_input, control_output, control_prev, output_dtype): + out = {'input':[], 'middle':[], 'output': []} + + if control_input is not None: + for i in range(len(control_input)): + key = 'input' + x = control_input[i] + if x is not None: + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + out[key].insert(0, x) + + if control_output is not None: + for i in range(len(control_output)): + if i == (len(control_output) - 1): + key = 'middle' + index = 0 + else: + key = 'output' + index = i + x = control_output[i] + if x is not None: + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + + out[key].append(x) + if control_prev is not None: + for x in ['input', 'middle', 'output']: + o = out[x] + for i in range(len(control_prev[x])): + prev_val = control_prev[x][i] + if i >= len(o): + o.append(prev_val) + elif prev_val is not None: + if o[i] is None: + o[i] = prev_val + else: + o[i] += prev_val + return out + +class ControlNet(ControlBase): + def 
__init__(self, control_model, global_average_pooling=False, device=None): + super().__init__(device) + self.control_model = control_model + self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + self.global_average_pooling = global_average_pooling + + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + + output_dtype = x_noisy.dtype + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + + + context = torch.cat(cond['c_crossattn'], 1) + y = cond.get('c_adm', None) + if y is not None: + y = y.to(self.control_model.dtype) + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + return self.control_merge(None, control, control_prev, output_dtype) + + def copy(self): + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) + self.copy_to(c) + return c + + def get_models(self): + out = super().get_models() + out.append(self.control_model_wrapped) + return out + +class ControlLoraOps: + class Linear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.up = None + self.down = None + self.bias = None + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + else: + return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + + class Conv2d(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = False + self.output_padding = 0 + self.groups = groups + self.padding_mode = padding_mode + + self.weight = None + self.bias = None + self.up = None + self.down = None + + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), 
self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + else: + return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) + + def conv_nd(self, dims, *args, **kwargs): + if dims == 2: + return self.Conv2d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class ControlLora(ControlNet): + def __init__(self, control_weights, global_average_pooling=False, device=None): + ControlBase.__init__(self, device) + self.control_weights = control_weights + self.global_average_pooling = global_average_pooling + + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + controlnet_config = model.model_config.unet_config.copy() + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] + controlnet_config["operations"] = ControlLoraOps() + self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + dtype = model.get_dtype() + self.control_model.to(dtype) + self.control_model.to(comfy.model_management.get_torch_device()) + diffusion_model = model.diffusion_model + sd = diffusion_model.state_dict() + cm = self.control_model.state_dict() + + for k in sd: + weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) + try: + comfy.utils.set_attr(self.control_model, k, weight) + except: + pass + + for k in self.control_weights: + if k not in {"lora_controlnet"}: + comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + + def copy(self): + c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) + self.copy_to(c) + return c + + def cleanup(self): + del self.control_model + self.control_model = None + super().cleanup() + + def get_models(self): + out = ControlBase.get_models(self) + return out + + def inference_memory_requirements(self, dtype): + return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) + +def load_controlnet(ckpt_path, model=None): + controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) + if "lora_controlnet" in controlnet_data: + return ControlLora(controlnet_data) + + controlnet_config = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format + use_fp16 = comfy.model_management.should_use_fp16() + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) + diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) + diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" + diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + k_in = "controlnet_down_blocks.{}{}".format(count, s) + k_out = "zero_convs.{}.0{}".format(count, s) + if k_in not in controlnet_data: + loop = False + break + diffusers_keys[k_in] = k_out + count += 1 + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + if count == 0: + k_in = "controlnet_cond_embedding.conv_in{}".format(s) + else: + k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) + k_out = 
"input_hint_block.{}{}".format(count * 2, s) + if k_in not in controlnet_data: + k_in = "controlnet_cond_embedding.conv_out{}".format(s) + loop = False + diffusers_keys[k_in] = k_out + count += 1 + + new_sd = {} + for k in diffusers_keys: + if k in controlnet_data: + new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + + leftover_keys = controlnet_data.keys() + if len(leftover_keys) > 0: + print("leftover keys:", leftover_keys) + controlnet_data = new_sd + + pth_key = 'control_model.zero_convs.0.0.weight' + pth = False + key = 'zero_convs.0.0.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + prefix = "control_model." + elif key in controlnet_data: + prefix = "" + else: + net = load_t2i_adapter(controlnet_data) + if net is None: + print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) + return net + + if controlnet_config is None: + use_fp16 = comfy.model_management.should_use_fp16() + controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] + control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + + if pth: + if 'difference' in controlnet_data: + if model is not None: + comfy.model_management.load_models_gpu([model]) + model_sd = model.model_state_dict() + for x in controlnet_data: + c_m = "control_model." + if x.startswith(c_m): + sd_key = "diffusion_model.{}".format(x[len(c_m):]) + if sd_key in model_sd: + cd = controlnet_data[x] + cd += model_sd[sd_key].type(cd.dtype).to(cd.device) + else: + print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") + + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.control_model = control_model + missing, unexpected = w.load_state_dict(controlnet_data, strict=False) + else: + missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) + print(missing, unexpected) + + if use_fp16: + control_model = control_model.half() + + global_average_pooling = False + filename = os.path.splitext(ckpt_path)[0] + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) + return control + +class T2IAdapter(ControlBase): + def __init__(self, t2i_model, channels_in, device=None): + super().__init__(device) + self.t2i_model = t2i_model + self.channels_in = channels_in + self.control_input = None + + def scale_image_to(self, width, height): + unshuffle_amount = self.t2i_model.unshuffle_amount + width = math.ceil(width / unshuffle_amount) * unshuffle_amount + height = math.ceil(height / unshuffle_amount) * unshuffle_amount + return width, height + + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.control_input = None + 
self.cond_hint = None + width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) + if self.channels_in == 1 and self.cond_hint.shape[1] > 1: + self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: + self.t2i_model.to(x_noisy.dtype) + self.t2i_model.to(self.device) + self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) + self.t2i_model.cpu() + + control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) + mid = None + if self.t2i_model.xl == True: + mid = control_input[-1:] + control_input = control_input[:-1] + return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) + + def copy(self): + c = T2IAdapter(self.t2i_model, self.channels_in) + self.copy_to(c) + return c + +def load_t2i_adapter(t2i_data): + keys = t2i_data.keys() + if 'adapter' in keys: + t2i_data = t2i_data['adapter'] + keys = t2i_data.keys() + if "body.0.in_conv.weight" in keys: + cin = t2i_data['body.0.in_conv.weight'].shape[1] + model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) + elif 'conv_in.weight' in keys: + cin = t2i_data['conv_in.weight'].shape[1] + channel = t2i_data['conv_in.weight'].shape[0] + ksize = t2i_data['body.0.block2.weight'].shape[2] + use_conv = False + down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys)) + if len(down_opts) > 0: + use_conv = True + xl = False + if cin == 256: + xl = True + model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + else: + return None + missing, unexpected = model_ad.load_state_dict(t2i_data) + if len(missing) > 0: + print("t2i missing", missing) + + if len(unexpected) > 0: + print("t2i unexpected", unexpected) + + return T2IAdapter(model_ad, model_ad.input_channels) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 4c396fab3..960a1ed6a 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -56,7 +56,18 @@ class Upsample(nn.Module): padding=1) def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + try: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + except: #operation not implemented for bf16 + b, c, h, w = x.shape + out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device) + split = 8 + l = out.shape[1] // split + for i in range(0, out.shape[1], l): + out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype) + del x + x = out + if self.with_conv: x = self.conv(x) return x @@ -74,11 +85,10 @@ class Downsample(nn.Module): stride=2, padding=0) - def forward(self, x, already_padded=False): + def forward(self, x): if self.with_conv: - if not already_padded: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -275,25 +285,17 
@@ class MemoryEfficientAttnBlock(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), (q, k, v), ) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + try: + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + out = out.transpose(1, 2).reshape(B, C, H, W) + except NotImplementedError as e: + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + out = self.proj_out(out) return x+out @@ -603,9 +605,6 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - already_padded = True # downsampling h = self.conv_in(x) for i_level in range(self.num_resolutions): @@ -614,8 +613,7 @@ class Encoder(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) if i_level != self.num_resolutions-1: - h = self.down[i_level].downsample(h, already_padded) - already_padded = False + h = self.down[i_level].downsample(h) # middle h = self.mid.block_1(h, temb) diff --git a/comfy/lora.py b/comfy/lora.py new file mode 100644 index 000000000..3009a1c9e --- /dev/null +++ b/comfy/lora.py @@ -0,0 +1,199 @@ +import comfy.utils + +LORA_CLIP_MAP = { + "mlp.fc1": "mlp_fc1", + "mlp.fc2": "mlp_fc2", + "self_attn.k_proj": "self_attn_k_proj", + "self_attn.q_proj": "self_attn_q_proj", + "self_attn.v_proj": "self_attn_v_proj", + "self_attn.out_proj": "self_attn_out_proj", +} + + +def load_lora(lora, to_load): + patch_dict = {} + loaded_keys = set() + for x in to_load: + alpha_name = "{}.alpha".format(x) + alpha = None + if alpha_name in lora.keys(): + alpha = lora[alpha_name].item() + loaded_keys.add(alpha_name) + + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name ="{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + + + ######## loha + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + 
hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + + + w_norm_name = "{}.w_norm".format(x) + b_norm_name = "{}.b_norm".format(x) + w_norm = lora.get(w_norm_name, None) + b_norm = lora.get(b_norm_name, None) + + if w_norm is not None: + loaded_keys.add(w_norm_name) + patch_dict[to_load[x]] = (w_norm,) + if b_norm is not None: + loaded_keys.add(b_norm_name) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) + + for x in lora.keys(): + if x not in loaded_keys: + print("lora key not loaded", x) + return patch_dict + +def model_lora_keys_clip(model, key_map={}): + sdk = model.state_dict().keys() + + text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" + clip_l_present = False + for b in range(32): + for c in LORA_CLIP_MAP: + k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + key_map[lora_key] = k + clip_l_present = True + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + if clip_l_present: + lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + key_map[lora_key] = k + lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + 
key_map[lora_key] = k + else: + lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + return key_map + +def model_lora_keys_unet(model, key_map={}): + sdk = model.state_dict().keys() + + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = k + + diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) + for k in diffusers_keys: + if k.endswith(".weight"): + unet_key = "diffusion_model.{}".format(diffusers_keys[k]) + key_lora = k[:-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = unet_key + + diffusers_lora_prefix = ["", "unet."] + for p in diffusers_lora_prefix: + diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) + if diffusers_lora_key.endswith(".to_out.0"): + diffusers_lora_key = diffusers_lora_key[:-2] + key_map[diffusers_lora_key] = unet_key + return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 4833f36e3..72e9bd2c0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -3,6 +3,7 @@ from .ldm.modules.diffusionmodules.openaimodel import UNetModel from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from .ldm.modules.diffusionmodules.util import make_beta_schedule from .ldm.modules.diffusionmodules.openaimodel import Timestep +import comfy.model_management import numpy as np from enum import Enum from . import utils @@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module): def state_dict_for_saving(self, clip_state_dict, vae_state_dict): clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) - unet_state_dict = self.diffusion_model.state_dict() + unet_sd = self.diffusion_model.state_dict() + unet_state_dict = {} + for k in unet_sd: + unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) + unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) if self.get_dtype() == torch.float16: diff --git a/comfy/model_management.py b/comfy/model_management.py index 08a77ccc9..e4958f2dc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,7 @@ import psutil from enum import Enum from .cli_args import args +import comfy.utils import torch import sys @@ -111,9 +112,6 @@ if not args.normalvram and not args.cpu: if lowvram_available and total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. 
If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError @@ -150,15 +148,27 @@ def is_nvidia(): return True ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention +VAE_DTYPE = torch.float32 -if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - try: - if is_nvidia(): - torch_version = torch.version.__version__ - if int(torch_version[0]) >= 2: + +try: + if is_nvidia(): + torch_version = torch.version.__version__ + if int(torch_version[0]) >= 2: + if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - except: - pass + if torch.cuda.is_bf16_supported(): + VAE_DTYPE = torch.bfloat16 +except: + pass + +if args.fp16_vae: + VAE_DTYPE = torch.float16 +elif args.bf16_vae: + VAE_DTYPE = torch.bfloat16 +elif args.fp32_vae: + VAE_DTYPE = torch.float32 + if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) @@ -230,6 +240,7 @@ try: except: print("Could not pick default device.") +print("VAE dtype:", VAE_DTYPE) current_loaded_models = [] @@ -302,16 +313,15 @@ def unload_model_clones(model): def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = False for i in range(len(current_loaded_models) -1, -1, -1): - if DISABLE_SMART_MEMORY: - current_free_mem = 0 - else: - current_free_mem = get_free_memory(device) - if current_free_mem > memory_required: - break + if not DISABLE_SMART_MEMORY: + if get_free_memory(device) > memory_required: + break shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: - current_loaded_models.pop(i).model_unload() + m = current_loaded_models.pop(i) + m.model_unload() + del m unloaded_model = True if unloaded_model: @@ -394,6 +404,12 @@ def cleanup_models(): x.model_unload() del x +def dtype_size(dtype): + dtype_size = 4 + if dtype == torch.float16 or dtype == torch.bfloat16: + dtype_size = 2 + return dtype_size + def unet_offload_device(): if vram_state == VRAMState.HIGH_VRAM: return get_torch_device() @@ -409,11 +425,7 @@ def unet_inital_load_device(parameters, dtype): if DISABLE_SMART_MEMORY: return cpu_dev - dtype_size = 4 - if dtype == torch.float16 or dtype == torch.bfloat16: - dtype_size = 2 - - model_size = dtype_size * parameters + model_size = dtype_size(dtype) * parameters mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) @@ -432,8 +444,7 @@ def text_encoder_device(): if args.gpu_only: return get_torch_device() elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: - #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU - if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough. 
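# Annotation (not part of the patch): the thread-count heuristic above is replaced by a
# capability check. With prioritize_performance=False, should_use_fp16() skips its
# free-memory test, so any fp16-capable GPU now hosts the text encoder regardless of CPU
# speed. A minimal, self-contained sketch of the resulting policy; gpu_fp16_capable is a
# hypothetical stand-in for should_use_fp16(prioritize_performance=False):
#
import torch

def pick_text_encoder_device(gpu_only: bool, gpu_fp16_capable: bool) -> torch.device:
    if gpu_only:
        return torch.device("cuda")
    # NORMAL_VRAM / HIGH_VRAM case: fp16 capability alone decides, not CPU thread count
    return torch.device("cuda") if gpu_fp16_capable else torch.device("cpu")

assert pick_text_encoder_device(False, True).type == "cuda"
assert pick_text_encoder_device(False, False).type == "cpu"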
+ if should_use_fp16(prioritize_performance=False): return get_torch_device() else: return torch.device("cpu") @@ -450,12 +461,8 @@ def vae_offload_device(): return torch.device("cpu") def vae_dtype(): - if args.fp16_vae: - return torch.float16 - elif args.bf16_vae: - return torch.bfloat16 - else: - return torch.float32 + global VAE_DTYPE + return VAE_DTYPE def get_autocast_device(dev): if hasattr(dev, 'type'): @@ -569,15 +576,19 @@ def is_device_mps(device): return True return False -def should_use_fp16(device=None, model_params=0): +def should_use_fp16(device=None, model_params=0, prioritize_performance=True): global xpu_available global directml_enabled + if device is not None: + if is_device_cpu(device): + return False + if FORCE_FP16: return True if device is not None: #TODO - if is_device_cpu(device) or is_device_mps(device): + if is_device_mps(device): return False if FORCE_FP32: @@ -610,7 +621,7 @@ def should_use_fp16(device=None, model_params=0): if fp16_works: free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) - if model_params * 4 > free_model_memory: + if (not prioritize_performance) or model_params * 4 > free_model_memory: return True if props.major < 7: @@ -636,6 +647,13 @@ def soft_empty_cache(): torch.cuda.empty_cache() torch.cuda.ipc_collect() +def resolve_lowvram_weight(weight, model, key): + if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. + key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. + op = comfy.utils.get_attr(model, '.'.join(key_split[:-1])) + weight = op._hf_hook.weights_map[key_split[-1]] + return weight + #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py new file mode 100644 index 000000000..2f087a600 --- /dev/null +++ b/comfy/model_patcher.py @@ -0,0 +1,270 @@ +import torch +import copy +import inspect + +import comfy.utils + +class ModelPatcher: + def __init__(self, model, load_device, offload_device, size=0, current_device=None): + self.size = size + self.model = model + self.patches = {} + self.backup = {} + self.model_options = {"transformer_options":{}} + self.model_size() + self.load_device = load_device + self.offload_device = offload_device + if current_device is None: + self.current_device = self.offload_device + else: + self.current_device = current_device + + def model_size(self): + if self.size > 0: + return self.size + model_sd = self.model.state_dict() + size = 0 + for k in model_sd: + t = model_sd[k] + size += t.nelement() * t.element_size() + self.size = size + self.model_keys = set(model_sd.keys()) + return size + + def clone(self): + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) + n.patches = {} + for k in self.patches: + n.patches[k] = self.patches[k][:] + + n.model_options = copy.deepcopy(self.model_options) + n.model_keys = self.model_keys + return n + + def is_clone(self, other): + if hasattr(other, 'model') and self.model is other.model: + return True + return False + + def set_model_sampler_cfg_function(self, sampler_cfg_function): + if len(inspect.signature(sampler_cfg_function).parameters) == 3: + self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way + else: + self.model_options["sampler_cfg_function"] = sampler_cfg_function + + def 
set_model_unet_function_wrapper(self, unet_wrapper_function): + self.model_options["model_function_wrapper"] = unet_wrapper_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_patch_replace(self, patch, name, block_name, number): + to = self.model_options["transformer_options"] + if "patches_replace" not in to: + to["patches_replace"] = {} + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + to["patches_replace"][name][(block_name, number)] = patch + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def set_model_attn1_replace(self, patch, block_name, number): + self.set_model_patch_replace(patch, "attn1", block_name, number) + + def set_model_attn2_replace(self, patch, block_name, number): + self.set_model_patch_replace(patch, "attn2", block_name, number) + + def set_model_attn1_output_patch(self, patch): + self.set_model_patch(patch, "attn1_output_patch") + + def set_model_attn2_output_patch(self, patch): + self.set_model_patch(patch, "attn2_output_patch") + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + patch_list[k] = patch_list[k].to(device) + + def model_dtype(self): + if hasattr(self.model, "get_dtype"): + return self.model.get_dtype() + + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): + p = set() + for k in patches: + if k in self.model_keys: + p.add(k) + current_patches = self.patches.get(k, []) + current_patches.append((strength_patch, patches[k], strength_model)) + self.patches[k] = current_patches + + return list(p) + + def get_key_patches(self, filter_prefix=None): + model_sd = self.model_state_dict() + p = {} + for k in model_sd: + if filter_prefix is not None: + if not k.startswith(filter_prefix): + continue + if k in self.patches: + p[k] = [model_sd[k]] + self.patches[k] + else: + p[k] = (model_sd[k],) + return p + + def model_state_dict(self, filter_prefix=None): + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd + + def patch_model(self, device_to=None): + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + print("could not patch. 
key doesn't exist in model:", key)
+                continue
+
+            weight = model_sd[key]
+
+            if key not in self.backup:
+                self.backup[key] = weight.to(self.offload_device)
+
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
+            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            comfy.utils.set_attr(self.model, key, out_weight)
+            del temp_weight
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
+
+        return self.model
+
+    def calculate_weight(self, patches, weight, key):
+        for p in patches:
+            alpha = p[0]
+            v = p[1]
+            strength_model = p[2]
+
+            if strength_model != 1.0:
+                weight *= strength_model
+
+            if isinstance(v, list):
+                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
+
+            if len(v) == 1:
+                w1 = v[0]
+                if alpha != 0.0:
+                    if w1.shape != weight.shape:
+                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                    else:
+                        weight += alpha * w1.type(weight.dtype).to(weight.device)
+            elif len(v) == 4: #lora/locon
+                mat1 = v[0].float().to(weight.device)
+                mat2 = v[1].float().to(weight.device)
+                if v[2] is not None:
+                    alpha *= v[2] / mat2.shape[0]
+                if v[3] is not None:
+                    #locon mid weights, hopefully the math is fine because I didn't properly test it
+                    mat3 = v[3].float().to(weight.device)
+                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+                try:
+                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+            elif len(v) == 8: #lokr
+                w1 = v[0]
+                w2 = v[1]
+                w1_a = v[3]
+                w1_b = v[4]
+                w2_a = v[5]
+                w2_b = v[6]
+                t2 = v[7]
+                dim = None
+
+                if w1 is None:
+                    dim = w1_b.shape[0]
+                    w1 = torch.mm(w1_a.float(), w1_b.float())
+                else:
+                    w1 = w1.float().to(weight.device)
+
+                if w2 is None:
+                    dim = w2_b.shape[0]
+                    if t2 is None:
+                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
+                    else:
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
+                else:
+                    w2 = w2.float().to(weight.device)
+
+                if len(w2.shape) == 4:
+                    w1 = w1.unsqueeze(2).unsqueeze(2)
+                if v[2] is not None and dim is not None:
+                    alpha *= v[2] / dim
+
+                try:
+                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+            else: #loha
+                w1a = v[0]
+                w1b = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / w1b.shape[0]
+                w2a = v[3]
+                w2b = v[4]
+                if v[5] is not None: #cp decomposition
+                    t1 = v[5]
+                    t2 = v[6]
+                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
+                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
+                else:
+                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
+                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
+
+                try:
+                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+
+        return weight
+
+    def unpatch_model(self, device_to=None):
+        keys = list(self.backup.keys())
+
+        for k in keys:
+            
comfy.utils.set_attr(self.model, k, self.backup[k]) + + self.backup = {} + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 2bfa7b850..a4bcc1729 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -22,7 +22,7 @@ from ..cli_args import args from ..cmd import folder_paths, latent_preview from ..nodes.common import MAX_RESOLUTION - +import comfy.controlnet class CLIPTextEncode: @classmethod @@ -226,14 +226,16 @@ class VAEDecode: class VAEDecodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} + return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), + "tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64}) + }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" CATEGORY = "_for_testing" - def decode(self, vae, samples): - return (vae.decode_tiled(samples["samples"]), ) + def decode(self, vae, samples, tile_size): + return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), ) class VAEEncode: @classmethod @@ -262,15 +264,17 @@ class VAEEncode: class VAEEncodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} + return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), + "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64}) + }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "_for_testing" - def encode(self, vae, pixels): + def encode(self, vae, pixels, tile_size): pixels = VAEEncode.vae_encode_crop_pixels(pixels) - t = vae.encode_tiled(pixels[:,:,:,:3]) + t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) return ({"samples":t}, ) class VAEEncodeForInpaint: @@ -552,7 +556,7 @@ class ControlNetLoader: def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) - controlnet = sd.load_controlnet(controlnet_path) + controlnet = comfy.controlnet.load_controlnet(controlnet_path) return (controlnet,) class DiffControlNetLoader: @@ -568,7 +572,7 @@ class DiffControlNetLoader: def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) - controlnet = sd.load_controlnet(controlnet_path, model) + controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) return (controlnet,) @@ -1292,7 +1296,7 @@ class LoadImage: input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(files), )}, + {"image": (sorted(files), {"image_upload": True})}, } CATEGORY = "image" @@ -1335,7 +1339,7 @@ class LoadImageMask: input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(files), ), + {"image": (sorted(files), {"image_upload": True}), "channel": (s._color_channels, ), } } diff --git a/comfy/ops.py b/comfy/ops.py index 678c2c6d0..610d54584 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs): raise ValueError(f"unsupported dimensions: {dims}") @contextmanager -def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way +def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way 
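# Annotation (not part of the patch): use_comfy_ops() now lets callers pin the device and
# dtype of every Linear constructed inside the context, which is how clip_vision.py above
# materializes its transformer directly in fp16 on the offload device. A self-contained
# sketch of the same monkey-patch pattern, using only torch:
#
import contextlib
import torch

@contextlib.contextmanager
def force_linear_defaults(force_device=None, force_dtype=None):
    real_linear = torch.nn.Linear
    def patched(in_features, out_features, bias=True, device=None, dtype=None):
        # a forced device/dtype wins over whatever the caller passed, as in the diff
        return real_linear(in_features, out_features, bias=bias,
                           device=force_device if force_device is not None else device,
                           dtype=force_dtype if force_dtype is not None else dtype)
    torch.nn.Linear = patched
    try:
        yield
    finally:
        torch.nn.Linear = real_linear  # always restore the original class

with force_linear_defaults(force_dtype=torch.float16):
    layer = torch.nn.Linear(4, 4)  # created in fp16 without the caller asking for it
assert layer.weight.dtype == torch.float16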
old_torch_nn_linear = torch.nn.Linear - torch.nn.Linear = Linear + force_device = device + force_dtype = dtype + def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): + if force_device is not None: + device = force_device + if force_dtype is not None: + dtype = force_dtype + return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) + + torch.nn.Linear = linear_with_dtype try: yield finally: diff --git a/comfy/sample.py b/comfy/sample.py index 1e95365f9..1d95b200d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type): models += [c[1][model_type]] return models -def get_additional_models(positive, negative): +def get_additional_models(positive, negative, dtype): """loads additional models in positive and negative conditioning""" control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) + inference_memory = 0 control_models = [] for m in control_nets: control_models += m.get_models() + inference_memory += m.inference_memory_requirements(dtype) gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") gligen = [x[1] for x in gligen] models = control_models + gligen - return models + return models, inference_memory def cleanup_additional_models(models): """cleanup additional models that were loaded""" @@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative noise_mask = prepare_mask(noise_mask, noise.shape, device) real_model = None - models = get_additional_models(positive, negative) - model_management.load_models_gpu([model] + models, model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3])) + models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) + model_management.load_models_gpu([model] + models, model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory) real_model = model.model noise = noise.to(device) diff --git a/comfy/sd.py b/comfy/sd.py index 764ea6819..6f434e5bb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,16 +1,14 @@ import torch import contextlib -import copy -import inspect +import math from . import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL import yaml -from .cldm import cldm -from .t2i_adapter import adapter -from . import utils +import comfy.utils + from . import clip_vision from . import gligen from . import diffusers_convert @@ -21,6 +19,10 @@ from . import sd1_clip from . import sd2_clip from . 
import sdxl_clip +import comfy.model_patcher +import comfy.lora +import comfy.t2i_adapter.adapter + def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) m = set(m) @@ -47,479 +49,14 @@ def load_clip_weights(model, sd): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) return load_model_weights(model, sd) -LORA_CLIP_MAP = { - "mlp.fc1": "mlp_fc1", - "mlp.fc2": "mlp_fc2", - "self_attn.k_proj": "self_attn_k_proj", - "self_attn.q_proj": "self_attn_q_proj", - "self_attn.v_proj": "self_attn_v_proj", - "self_attn.out_proj": "self_attn_out_proj", -} - - -def load_lora(lora, to_load): - patch_dict = {} - loaded_keys = set() - for x in to_load: - alpha_name = "{}.alpha".format(x) - alpha = None - if alpha_name in lora.keys(): - alpha = lora[alpha_name].item() - loaded_keys.add(alpha_name) - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name ="{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - 
loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) - - for x in lora.keys(): - if x not in loaded_keys: - print("lora key not loaded", x) - return patch_dict - -def model_lora_keys_clip(model, key_map={}): - sdk = model.state_dict().keys() - - text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" - clip_l_present = False - for b in range(32): - for c in LORA_CLIP_MAP: - k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base - key_map[lora_key] = k - clip_l_present = True - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - if clip_l_present: - lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base - key_map[lora_key] = k - lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - else: - lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - return key_map - -def model_lora_keys_unet(model, key_map={}): - sdk = model.state_dict().keys() - - for k in sdk: - if k.startswith("diffusion_model.") and k.endswith(".weight"): - key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = k - - diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config) - for k in diffusers_keys: - if k.endswith(".weight"): - unet_key = "diffusion_model.{}".format(diffusers_keys[k]) - key_lora = k[:-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = unet_key - - diffusers_lora_prefix = ["", "unet."] - for p in diffusers_lora_prefix: - diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) - if diffusers_lora_key.endswith(".to_out.0"): - diffusers_lora_key = diffusers_lora_key[:-2] - key_map[diffusers_lora_key] = unet_key - return key_map - -def set_attr(obj, attr, value): - attrs = attr.split(".") - for name in attrs[:-1]: - obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value)) - del prev - -def get_attr(obj, attr): - attrs = attr.split(".") - for name in attrs: - obj = getattr(obj, name) 
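# Annotation (not part of the patch): the set_attr/get_attr helpers being deleted from
# comfy/sd.py are not dropped; the new code calls comfy.utils.set_attr and
# comfy.utils.get_attr (see model_patcher.py and resolve_lowvram_weight above), so they
# presumably moved into comfy/utils.py. A self-contained sketch of the dotted-path
# traversal they perform:
#
import torch

def get_attr(obj, attr):
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj

def set_attr(obj, attr, value):
    *path, leaf = attr.split(".")
    for name in path:
        obj = getattr(obj, name)
    prev = getattr(obj, leaf)
    setattr(obj, leaf, torch.nn.Parameter(value))  # rebind the leaf as a Parameter
    del prev  # drop the reference to the old tensor, as the original helper did

m = torch.nn.Sequential(torch.nn.Linear(2, 2))
set_attr(m, "0.weight", torch.zeros(2, 2))
assert get_attr(m, "0.weight").sum().item() == 0.0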
-    return obj
-
-
-class ModelPatcher:
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None):
-        self.size = size
-        self.model = model
-        self.patches = {}
-        self.backup = {}
-        self.model_options = {"transformer_options":{}}
-        self.model_size()
-        self.load_device = load_device
-        self.offload_device = offload_device
-        if current_device is None:
-            self.current_device = self.offload_device
-        else:
-            self.current_device = current_device
-
-    def model_size(self):
-        if self.size > 0:
-            return self.size
-        model_sd = self.model.state_dict()
-        size = 0
-        for k in model_sd:
-            t = model_sd[k]
-            size += t.nelement() * t.element_size()
-        self.size = size
-        self.model_keys = set(model_sd.keys())
-        return size
-
-    def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
-        n.patches = {}
-        for k in self.patches:
-            n.patches[k] = self.patches[k][:]
-
-        n.model_options = copy.deepcopy(self.model_options)
-        n.model_keys = self.model_keys
-        return n
-
-    def is_clone(self, other):
-        if hasattr(other, 'model') and self.model is other.model:
-            return True
-        return False
-
-    def set_model_sampler_cfg_function(self, sampler_cfg_function):
-        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
-            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
-        else:
-            self.model_options["sampler_cfg_function"] = sampler_cfg_function
-
-    def set_model_unet_function_wrapper(self, unet_wrapper_function):
-        self.model_options["model_function_wrapper"] = unet_wrapper_function
-
-    def set_model_patch(self, patch, name):
-        to = self.model_options["transformer_options"]
-        if "patches" not in to:
-            to["patches"] = {}
-        to["patches"][name] = to["patches"].get(name, []) + [patch]
-
-    def set_model_patch_replace(self, patch, name, block_name, number):
-        to = self.model_options["transformer_options"]
-        if "patches_replace" not in to:
-            to["patches_replace"] = {}
-        if name not in to["patches_replace"]:
-            to["patches_replace"][name] = {}
-        to["patches_replace"][name][(block_name, number)] = patch
-
-    def set_model_attn1_patch(self, patch):
-        self.set_model_patch(patch, "attn1_patch")
-
-    def set_model_attn2_patch(self, patch):
-        self.set_model_patch(patch, "attn2_patch")
-
-    def set_model_attn1_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn1", block_name, number)
-
-    def set_model_attn2_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn2", block_name, number)
-
-    def set_model_attn1_output_patch(self, patch):
-        self.set_model_patch(patch, "attn1_output_patch")
-
-    def set_model_attn2_output_patch(self, patch):
-        self.set_model_patch(patch, "attn2_output_patch")
-
-    def model_patches_to(self, device):
-        to = self.model_options["transformer_options"]
-        if "patches" in to:
-            patches = to["patches"]
-            for name in patches:
-                patch_list = patches[name]
-                for i in range(len(patch_list)):
-                    if hasattr(patch_list[i], "to"):
-                        patch_list[i] = patch_list[i].to(device)
-        if "patches_replace" in to:
-            patches = to["patches_replace"]
-            for name in patches:
-                patch_list = patches[name]
-                for k in patch_list:
-                    if hasattr(patch_list[k], "to"):
-                        patch_list[k] = patch_list[k].to(device)
-
-    def model_dtype(self):
-        if hasattr(self.model, "get_dtype"):
-            return self.model.get_dtype()
-
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
-        for k in patches:
-            if k in self.model_keys:
-                p.add(k)
-                current_patches = self.patches.get(k, [])
-                current_patches.append((strength_patch, patches[k], strength_model))
-                self.patches[k] = current_patches
-
-        return list(p)
-
-    def get_key_patches(self, filter_prefix=None):
-        model_sd = self.model_state_dict()
-        p = {}
-        for k in model_sd:
-            if filter_prefix is not None:
-                if not k.startswith(filter_prefix):
-                    continue
-            if k in self.patches:
-                p[k] = [model_sd[k]] + self.patches[k]
-            else:
-                p[k] = (model_sd[k],)
-        return p
-
-    def model_state_dict(self, filter_prefix=None):
-        sd = self.model.state_dict()
-        keys = list(sd.keys())
-        if filter_prefix is not None:
-            for k in keys:
-                if not k.startswith(filter_prefix):
-                    sd.pop(k)
-        return sd
-
-    def patch_model(self, device_to=None):
-        model_sd = self.model_state_dict()
-        for key in self.patches:
-            if key not in model_sd:
-                print("could not patch. key doesn't exist in model:", k)
-                continue
-
-            weight = model_sd[key]
-
-            if key not in self.backup:
-                self.backup[key] = weight.to(self.offload_device)
-
-            if device_to is not None:
-                temp_weight = weight.float().to(device_to, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
-            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-            set_attr(self.model, key, out_weight)
-            del temp_weight
-
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
-
-        return self.model
-
-    def calculate_weight(self, patches, weight, key):
-        for p in patches:
-            alpha = p[0]
-            v = p[1]
-            strength_model = p[2]
-
-            if strength_model != 1.0:
-                weight *= strength_model
-
-            if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
-
-            if len(v) == 1:
-                w1 = v[0]
-                if alpha != 0.0:
-                    if w1.shape != weight.shape:
-                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
-                    else:
-                        weight += alpha * w1.type(weight.dtype).to(weight.device)
-            elif len(v) == 4: #lora/locon
-                mat1 = v[0].float().to(weight.device)
-                mat2 = v[1].float().to(weight.device)
-                if v[2] is not None:
-                    alpha *= v[2] / mat2.shape[0]
-                if v[3] is not None:
-                    #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    mat3 = v[3].float().to(weight.device)
-                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                try:
-                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-            elif len(v) == 8: #lokr
-                w1 = v[0]
-                w2 = v[1]
-                w1_a = v[3]
-                w1_b = v[4]
-                w2_a = v[5]
-                w2_b = v[6]
-                t2 = v[7]
-                dim = None
-
-                if w1 is None:
-                    dim = w1_b.shape[0]
-                    w1 = torch.mm(w1_a.float(), w1_b.float())
-                else:
-                    w1 = w1.float().to(weight.device)
-
-                if w2 is None:
-                    dim = w2_b.shape[0]
-                    if t2 is None:
-                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
-                    else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
-                else:
-                    w2 = w2.float().to(weight.device)
-
-                if len(w2.shape) == 4:
-                    w1 = w1.unsqueeze(2).unsqueeze(2)
-                if v[2] is not None and dim is not None:
-                    alpha *= v[2] / dim
-
-                try:
-                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-            else: #loha
-                w1a = v[0]
-                w1b = v[1]
-                if v[2] is not None:
-                    alpha *= v[2] / w1b.shape[0]
-                w2a = v[3]
-                w2b = v[4]
-                if v[5] is not None: #cp decomposition
-                    t1 = v[5]
-                    t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
-                else:
-                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
-                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
-
-                try:
-                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-
-        return weight
-
-    def unpatch_model(self, device_to=None):
-        keys = list(self.backup.keys())
-
-        for k in keys:
-            set_attr(self.model, k, self.backup[k])
-
-        self.backup = {}
-
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
-
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
-    key_map = model_lora_keys_unet(model.model)
-    key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
-    loaded = load_lora(lora, key_map)
+    key_map = comfy.lora.model_lora_keys_unet(model.model)
+    key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+    loaded = comfy.lora.load_lora(lora, key_map)
     new_modelpatcher = model.clone()
     k = new_modelpatcher.add_patches(loaded, strength_model)
     new_clip = clip.clone()
@@ -543,16 +80,16 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = load_device
-        self.cond_stage_model = clip(**(params))
-        #TODO: make sure this doesn't have a quality loss before enabling.
-        # if model_management.should_use_fp16(load_device):
-        #     self.cond_stage_model.half()
+        params['device'] = offload_device
+        if model_management.should_use_fp16(load_device, prioritize_performance=False):
+            params['dtype'] = torch.float16
+        else:
+            params['dtype'] = torch.float32
-        self.cond_stage_model = self.cond_stage_model.to()
+        self.cond_stage_model = clip(**(params))

         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
-        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
@@ -563,9 +100,6 @@ class CLIP:
         n.layer_idx = self.layer_idx
         return n

-    def load_from_state_dict(self, sd):
-        self.cond_stage_model.load_sd(sd)
-
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -614,7 +148,7 @@ class VAE:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
         if ckpt_path is not None:
-            sd = utils.load_torch_file(ckpt_path)
+            sd = comfy.utils.load_torch_file(ckpt_path)
             if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
                 sd = diffusers_convert.convert_vae_state_dict(sd)
             self.first_stage_model.load_state_dict(sd, strict=False)
@@ -627,29 +161,29 @@
         self.first_stage_model.to(self.vae_dtype)

     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
-        steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
-        steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
-        steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
-        pbar = utils.ProgressBar(steps)
+        steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
+        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+        pbar = comfy.utils.ProgressBar(steps)

         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
-            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
-             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
+             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output

     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
-        steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
-        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
-        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
-        pbar = utils.ProgressBar(steps)
+        steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
+        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+        pbar = comfy.utils.ProgressBar(steps)

         encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float()
-        samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples /= 3.0
         return samples
@@ -711,453 +245,6 @@ class VAE:
     def get_sd(self):
         return self.first_stage_model.state_dict()

-
-def broadcast_image_to(tensor, target_batch_size, batched_number):
-    current_batch_size = tensor.shape[0]
-    #print(current_batch_size, target_batch_size)
-    if current_batch_size == 1:
-        return tensor
-
-    per_batch = target_batch_size // batched_number
-    tensor = tensor[:per_batch]
-
-    if per_batch > tensor.shape[0]:
-        tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
-
-    current_batch_size = tensor.shape[0]
-    if current_batch_size == target_batch_size:
-        return tensor
-    else:
-        return torch.cat([tensor] * batched_number, dim=0)
-
-class ControlBase:
-    def __init__(self, device=None):
-        self.cond_hint_original = None
-        self.cond_hint = None
-        self.strength = 1.0
-        self.timestep_percent_range = (1.0, 0.0)
-        self.timestep_range = None
-
-        if device is None:
-            device = model_management.get_torch_device()
-        self.device = device
-        self.previous_controlnet = None
-        self.global_average_pooling = False
-
-    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
-        self.cond_hint_original = cond_hint
-        self.strength = strength
-        self.timestep_percent_range = timestep_percent_range
-        return self
-
-    def pre_run(self, model, percent_to_timestep_function):
-        self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.pre_run(model, percent_to_timestep_function)
-
-    def set_previous_controlnet(self, controlnet):
-        self.previous_controlnet = controlnet
-        return self
-
-    def cleanup(self):
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
-        self.timestep_range = None
-
-    def get_models(self):
-        out = []
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_models()
-        return out
-
-    def copy_to(self, c):
-        c.cond_hint_original = self.cond_hint_original
-        c.strength = self.strength
-        c.timestep_percent_range = self.timestep_percent_range
-
-    def control_merge(self, control_input, control_output, control_prev, output_dtype):
-        out = {'input':[], 'middle':[], 'output': []}
-
-        if control_input is not None:
-            for i in range(len(control_input)):
-                key = 'input'
-                x = control_input[i]
-                if x is not None:
-                    x *= self.strength
-                    if x.dtype != output_dtype:
-                        x = x.to(output_dtype)
-                out[key].insert(0, x)
-
-        if control_output is not None:
-            for i in range(len(control_output)):
-                if i == (len(control_output) - 1):
-                    key = 'middle'
-                    index = 0
-                else:
-                    key = 'output'
-                    index = i
-                x = control_output[i]
-                if x is not None:
-                    if self.global_average_pooling:
-                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
-
-                    x *= self.strength
-                    if x.dtype != output_dtype:
-                        x = x.to(output_dtype)
-
-                out[key].append(x)
-        if control_prev is not None:
-            for x in ['input', 'middle', 'output']:
-                o = out[x]
-                for i in range(len(control_prev[x])):
-                    prev_val = control_prev[x][i]
-                    if i >= len(o):
-                        o.append(prev_val)
-                    elif prev_val is not None:
-                        if o[i] is None:
-                            o[i] = prev_val
-                        else:
-                            o[i] += prev_val
-        return out
-
-class ControlNet(ControlBase):
-    def __init__(self, control_model, global_average_pooling=False, device=None):
-        super().__init__(device)
-        self.control_model = control_model
-        self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
-        self.global_average_pooling = global_average_pooling
-
-    def get_control(self, x_noisy, t, cond, batched_number):
-        control_prev = None
-        if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
-
-        if self.timestep_range is not None:
-            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
-                if control_prev is not None:
-                    return control_prev
-                else:
-                    return {}
-
-        output_dtype = x_noisy.dtype
-        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
-            if self.cond_hint is not None:
-                del self.cond_hint
-                self.cond_hint = None
-            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
-        if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
-
-
-        context = torch.cat(cond['c_crossattn'], 1)
-        y = cond.get('c_adm', None)
-        if y is not None:
-            y = y.to(self.control_model.dtype)
-        control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
-        return self.control_merge(None, control, control_prev, output_dtype)
-
-    def copy(self):
-        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
-        self.copy_to(c)
-        return c
-
-    def get_models(self):
-        out = super().get_models()
-        out.append(self.control_model_wrapped)
-        return out
-
-class ControlLoraOps:
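# ControlLoraOps supplies Linear/Conv2d stand-ins that keep weight/bias as plain
# attributes and, when low-rank 'up'/'down' factors are attached, recompose the
# effective weight as weight + up @ down at forward time.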
-    class Linear(torch.nn.Module):
-        def __init__(self, in_features: int, out_features: int, bias: bool = True,
-                     device=None, dtype=None) -> None:
-            factory_kwargs = {'device': device, 'dtype': dtype}
-            super().__init__()
-            self.in_features = in_features
-            self.out_features = out_features
-            self.weight = None
-            self.up = None
-            self.down = None
-            self.bias = None
-
-        def forward(self, input):
-            if self.up is not None:
-                return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
-            else:
-                return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
-
-    class Conv2d(torch.nn.Module):
-        def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=1,
-            padding=0,
-            dilation=1,
-            groups=1,
-            bias=True,
-            padding_mode='zeros',
-            device=None,
-            dtype=None
-        ):
-            super().__init__()
-            self.in_channels = in_channels
-            self.out_channels = out_channels
-            self.kernel_size = kernel_size
-            self.stride = stride
-            self.padding = padding
-            self.dilation = dilation
-            self.transposed = False
-            self.output_padding = 0
-            self.groups = groups
-            self.padding_mode = padding_mode
-
-            self.weight = None
-            self.bias = None
-            self.up = None
-            self.down = None
-
-
-        def forward(self, input):
-            if self.up is not None:
-                return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
-            else:
-                return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
-
-    def conv_nd(self, dims, *args, **kwargs):
-        if dims == 2:
-            return self.Conv2d(*args, **kwargs)
-        else:
-            raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class ControlLora(ControlNet):
-    def __init__(self, control_weights, global_average_pooling=False, device=None):
-        ControlBase.__init__(self, device)
-        self.control_weights = control_weights
-        self.global_average_pooling = global_average_pooling
-
-    def pre_run(self, model, percent_to_timestep_function):
-        super().pre_run(model, percent_to_timestep_function)
-        controlnet_config = model.model_config.unet_config.copy()
-        controlnet_config.pop("out_channels")
-        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
-        controlnet_config["operations"] = ControlLoraOps()
-        self.control_model = cldm.ControlNet(**controlnet_config)
-        dtype = model.get_dtype()
-        self.control_model.to(dtype)
-        self.control_model.to(model_management.get_torch_device())
-        diffusion_model = model.diffusion_model
-        sd = diffusion_model.state_dict()
-        cm = self.control_model.state_dict()
-
-        for k in sd:
-            weight = sd[k]
-            if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-                key_split = k.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
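# accelerate's lowvram offload replaces module weights with 'meta' tensors and
# keeps the real values in the hook's weights_map; the two lines below walk the
# attribute path to the module and pull the actual tensor back out of that map.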
-                op = get_attr(diffusion_model, '.'.join(key_split[:-1]))
-                weight = op._hf_hook.weights_map[key_split[-1]]
-
-            try:
-                set_attr(self.control_model, k, weight)
-            except:
-                pass
-
-        for k in self.control_weights:
-            if k not in {"lora_controlnet"}:
-                set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(model_management.get_torch_device()))
-
-    def copy(self):
-        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
-        self.copy_to(c)
-        return c
-
-    def cleanup(self):
-        del self.control_model
-        self.control_model = None
-        super().cleanup()
-
-    def get_models(self):
-        out = ControlBase.get_models(self)
-        return out
-
-def load_controlnet(ckpt_path, model=None):
-    controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
-    if "lora_controlnet" in controlnet_data:
-        return ControlLora(controlnet_data)
-
-    controlnet_config = None
-    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
-        use_fp16 = model_management.should_use_fp16()
-        controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
-        diffusers_keys = utils.unet_to_diffusers(controlnet_config)
-        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
-        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
-
-        count = 0
-        loop = True
-        while loop:
-            suffix = [".weight", ".bias"]
-            for s in suffix:
-                k_in = "controlnet_down_blocks.{}{}".format(count, s)
-                k_out = "zero_convs.{}.0{}".format(count, s)
-                if k_in not in controlnet_data:
-                    loop = False
-                    break
-                diffusers_keys[k_in] = k_out
-            count += 1
-
-        count = 0
-        loop = True
-        while loop:
-            suffix = [".weight", ".bias"]
-            for s in suffix:
-                if count == 0:
-                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
-                else:
-                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
-                k_out = "input_hint_block.{}{}".format(count * 2, s)
-                if k_in not in controlnet_data:
-                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
-                    loop = False
-                diffusers_keys[k_in] = k_out
-            count += 1
-
-        new_sd = {}
-        for k in diffusers_keys:
-            if k in controlnet_data:
-                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
-
-        leftover_keys = controlnet_data.keys()
-        if len(leftover_keys) > 0:
-            print("leftover keys:", leftover_keys)
-        controlnet_data = new_sd
-
-    pth_key = 'control_model.zero_convs.0.0.weight'
-    pth = False
-    key = 'zero_convs.0.0.weight'
-    if pth_key in controlnet_data:
-        pth = True
-        key = pth_key
-        prefix = "control_model."
-    elif key in controlnet_data:
-        prefix = ""
-    else:
-        net = load_t2i_adapter(controlnet_data)
-        if net is None:
-            print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
-        return net
-
-    if controlnet_config is None:
-        use_fp16 = model_management.should_use_fp16()
-        controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
-    controlnet_config.pop("out_channels")
-    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
-    control_model = cldm.ControlNet(**controlnet_config)
-
-    if pth:
-        if 'difference' in controlnet_data:
-            if model is not None:
-                model_management.load_models_gpu([model])
-                model_sd = model.model_state_dict()
-                for x in controlnet_data:
-                    c_m = "control_model."
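# 'difference' checkpoints store the controlnet weights as deltas against the
# base UNet, so the loop below adds the matching base weight back onto each
# delta before the state dict is loaded.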
-                    if x.startswith(c_m):
-                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
-                        if sd_key in model_sd:
-                            cd = controlnet_data[x]
-                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
-            else:
-                print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
-
-        class WeightsLoader(torch.nn.Module):
-            pass
-        w = WeightsLoader()
-        w.control_model = control_model
-        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
-    else:
-        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
-    print(missing, unexpected)
-
-    if use_fp16:
-        control_model = control_model.half()
-
-    global_average_pooling = False
-    if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
-        global_average_pooling = True
-
-    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
-    return control
-
-class T2IAdapter(ControlBase):
-    def __init__(self, t2i_model, channels_in, device=None):
-        super().__init__(device)
-        self.t2i_model = t2i_model
-        self.channels_in = channels_in
-        self.control_input = None
-
-    def get_control(self, x_noisy, t, cond, batched_number):
-        control_prev = None
-        if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
-
-        if self.timestep_range is not None:
-            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
-                if control_prev is not None:
-                    return control_prev
-                else:
-                    return {}
-
-        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
-            if self.cond_hint is not None:
-                del self.cond_hint
-            self.control_input = None
-            self.cond_hint = None
-            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
-            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
-                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
-        if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
-        if self.control_input is None:
-            self.t2i_model.to(x_noisy.dtype)
-            self.t2i_model.to(self.device)
-            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
-            self.t2i_model.cpu()
-
-        control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
-        return self.control_merge(control_input, None, control_prev, x_noisy.dtype)
-
-    def copy(self):
-        c = T2IAdapter(self.t2i_model, self.channels_in)
-        self.copy_to(c)
-        return c
-
-def load_t2i_adapter(t2i_data):
-    keys = t2i_data.keys()
-    if 'adapter' in keys:
-        t2i_data = t2i_data['adapter']
-        keys = t2i_data.keys()
-    if "body.0.in_conv.weight" in keys:
-        cin = t2i_data['body.0.in_conv.weight'].shape[1]
-        model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
-    elif 'conv_in.weight' in keys:
-        cin = t2i_data['conv_in.weight'].shape[1]
-        channel = t2i_data['conv_in.weight'].shape[0]
-        ksize = t2i_data['body.0.block2.weight'].shape[2]
-        use_conv = False
-        down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
-        if len(down_opts) > 0:
-            use_conv = True
-        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
-    else:
-        return None
-    model_ad.load_state_dict(t2i_data)
-    return T2IAdapter(model_ad, cin // 64)
-
-
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
@@ -1167,10 +254,10 @@ class StyleModel:

 def load_style_model(ckpt_path):
-    model_data = utils.load_torch_file(ckpt_path, safe_load=True)
+    model_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     keys = model_data.keys()
     if "style_embedding" in keys:
-        model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
+        model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
     model.load_state_dict(model_data)
@@ -1180,14 +267,14 @@ def load_style_model(ckpt_path):

 def load_clip(ckpt_paths, embedding_directory=None):
     clip_data = []
     for p in ckpt_paths:
-        clip_data.append(utils.load_torch_file(p, safe_load=True))
+        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))

     class EmptyClass:
         pass

     for i in range(len(clip_data)):
         if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
-            clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+            clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32)

     clip_target = EmptyClass()
     clip_target.params = {}
@@ -1216,11 +303,11 @@ def load_clip(ckpt_paths, embedding_directory=None):
     return clip

 def load_gligen(ckpt_path):
-    data = utils.load_torch_file(ckpt_path, safe_load=True)
+    data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     model = gligen.load_gligen(data)
     if model_management.should_use_fp16():
         model = model.half()
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
     #TODO: this function is a mess and should be removed eventually
@@ -1256,7 +343,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
             pass

     if state_dict is None:
-        state_dict = utils.load_torch_file(ckpt_path)
+        state_dict = comfy.utils.load_torch_file(ckpt_path)

     class EmptyClass:
         pass
@@ -1300,17 +387,10 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
             w.cond_stage_model = clip.cond_stage_model
             load_clip_weights(w, state_dict)

-    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
-
-def calculate_parameters(sd, prefix):
-    params = 0
-    for k in sd.keys():
-        if k.startswith(prefix):
-            params += sd[k].nelement()
-    return params
+    return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)

 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
-    sd = utils.load_torch_file(ckpt_path)
+    sd = comfy.utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
     clipvision = None
@@ -1318,7 +398,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model = None
     clip_target = None

-    parameters = calculate_parameters(sd, "model.diffusion_model.")
+    parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
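For reference, calculate_parameters (its new home is the comfy/utils.py hunk further down) just sums tensor element counts under a key prefix, and that count feeds into should_use_fp16's decision. A minimal sketch with a hypothetical two-tensor checkpoint:

```python
import torch

def calculate_parameters(sd, prefix=""):
    # Mirrors the helper added to comfy/utils.py: count the elements of every
    # tensor whose key starts with the given prefix.
    params = 0
    for k in sd.keys():
        if k.startswith(prefix):
            params += sd[k].nelement()
    return params

# Hypothetical checkpoint: only diffusion-model tensors are counted.
sd = {
    "model.diffusion_model.a.weight": torch.zeros(320, 320),
    "model.diffusion_model.b.weight": torch.zeros(640, 640),
    "first_stage_model.c.weight": torch.zeros(8, 8),  # skipped by the prefix
}
print(calculate_parameters(sd, "model.diffusion_model."))  # 512000
```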
     fp16 = model_management.should_use_fp16(model_params=parameters)

     class WeightsLoader(torch.nn.Module):
@@ -1359,7 +439,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)

-    model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
     if inital_load_device != torch.device("cpu"):
         print("loaded straight to GPU")
         model_management.load_model_gpu(model_patcher)
@@ -1368,8 +448,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

 def load_unet(unet_path): #load unet in diffusers format
-    sd = utils.load_torch_file(unet_path)
-    parameters = calculate_parameters(sd, "")
+    sd = comfy.utils.load_torch_file(unet_path)
+    parameters = comfy.utils.calculate_parameters(sd)
     fp16 = model_management.should_use_fp16(model_params=parameters)

     model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
@@ -1377,7 +457,7 @@ def load_unet(unet_path): #load unet in diffusers format
         print("ERROR UNSUPPORTED UNET", unet_path)
         return None

-    diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
+    diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

     new_sd = {}
     for k in diffusers_keys:
@@ -1389,9 +469,9 @@ def load_unet(unet_path): #load unet in diffusers format
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)

 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     model_management.load_models_gpu([model, clip.load_model()])
     sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
-    utils.save_torch_file(sd, output_path, metadata=metadata)
+    comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index c4b5b06f2..59cd6a1ba 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -44,7 +44,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None):  # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
         self.num_layers = 12
@@ -57,17 +57,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
         config = CLIPTextConfig.from_json_file(textmodel_json_config)
         self.num_layers = config.num_hidden_layers
-        with ops.use_comfy_ops():
+        with ops.use_comfy_ops(device, dtype):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)
+        if dtype is not None:
+            self.transformer.to(dtype)
+
         self.max_length = max_length
         if freeze:
             self.freeze()
         self.layer = layer
         self.layer_idx = None
         self.empty_tokens = [[49406] + [49407] * 76]
-        self.text_projection = None
+        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+
+        self.layer_norm_hidden_state = True
         if layer == "hidden":
             assert layer_idx is not None
@@ -140,9 +144,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if backup_embeds.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
-            precision_scope = contextlib.nullcontext
+            precision_scope = lambda a, b: contextlib.nullcontext(a)

-        with precision_scope(model_management.get_autocast_device(device)):
+        with precision_scope(model_management.get_autocast_device(device), torch.float32):
             outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
             self.transformer.set_input_embeddings(backup_embeds)

@@ -157,13 +161,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             pooled_output = outputs.pooler_output

         if self.text_projection is not None:
-            pooled_output = pooled_output.to(self.text_projection.device) @ self.text_projection
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output.float()

     def encode(self, tokens):
         return self(tokens)

     def load_sd(self, sd):
+        if "text_projection" in sd:
+            self.text_projection[:] = sd.pop("text_projection")
+        if "text_projection.weight" in sd:
+            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
         return self.transformer.load_state_dict(sd, strict=False)

 def parse_parentheses(string):
diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py
index bf816eaa6..272c65a0a 100644
--- a/comfy/sd2_clip.py
+++ b/comfy/sd2_clip.py
@@ -1,11 +1,10 @@
 from pkg_resources import resource_filename
 from . import sd1_clip
-import torch
 import os

 class SD2ClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=23
@@ -13,7 +12,7 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
         if not os.path.exists(textmodel_json_config):
             textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]

     def clip_layer(self, layer_idx):
diff --git a/comfy/sd2_clip_config.json b/comfy/sd2_clip_config.json
index ace6ef001..85cec832b 100644
--- a/comfy/sd2_clip_config.json
+++ b/comfy/sd2_clip_config.json
@@ -17,7 +17,7 @@
   "num_attention_heads": 16,
   "num_hidden_layers": 24,
   "pad_token_id": 1,
-  "projection_dim": 512,
+  "projection_dim": 1024,
   "torch_dtype": "float32",
   "vocab_size": 49408
 }
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 6146bafdf..afbdeadc4 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -3,23 +3,17 @@ import torch
 import os

 class SDXLClipG(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
-        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = False

     def load_sd(self, sd):
-        if "text_projection" in sd:
-            self.text_projection[:] = sd.pop("text_projection")
-        if "text_projection.weight" in sd:
-            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
         return super().load_sd(sd)

 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
@@ -42,11 +36,11 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
         return self.clip_g.untokenize(token_weight_pair)

 class SDXLClipModel(torch.nn.Module):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device)
+        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
         self.clip_l.layer_norm_hidden_state = False
-        self.clip_g = SDXLClipG(device=device)
+        self.clip_g = SDXLClipG(device=device, dtype=dtype)

     def clip_layer(self, layer_idx):
         self.clip_l.clip_layer(layer_idx)
@@ -70,9 +64,9 @@ class SDXLClipModel(torch.nn.Module):
         return self.clip_l.load_sd(sd)

 class SDXLRefinerClipModel(torch.nn.Module):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_g = SDXLClipG(device=device)
+        self.clip_g = SDXLClipG(device=device, dtype=dtype)

     def clip_layer(self, layer_idx):
         self.clip_g.clip_layer(layer_idx)
diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py
index 3647c4cf7..e9a606b1c 100644
--- a/comfy/t2i_adapter/adapter.py
+++ b/comfy/t2i_adapter/adapter.py
@@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):

 class Adapter(nn.Module):
-    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
+    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
         super(Adapter, self).__init__()
-        self.unshuffle = nn.PixelUnshuffle(8)
+        self.unshuffle_amount = 8
+        resblock_no_downsample = []
+        resblock_downsample = [3, 2, 1]
+        self.xl = xl
+        if self.xl:
+            self.unshuffle_amount = 16
+            resblock_no_downsample = [1]
+            resblock_downsample = [2]
+
+        self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
+        self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
         self.channels = channels
         self.nums_rb = nums_rb
         self.body = []
         for i in range(len(channels)):
             for j in range(nums_rb):
-                if (i != 0) and (j == 0):
+                if (i in resblock_downsample) and (j == 0):
                     self.body.append(
                         ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
+                elif (i in resblock_no_downsample) and (j == 0):
+                    self.body.append(
+                        ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
                 else:
                     self.body.append(
                         ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
@@ -128,8 +141,16 @@ class Adapter(nn.Module):
             for j in range(self.nums_rb):
                 idx = i * self.nums_rb + j
                 x = self.body[idx](x)
-            features.append(None)
-            features.append(None)
+            if self.xl:
+                features.append(None)
+                if i == 0:
+                    features.append(None)
+                    features.append(None)
+                if i == 2:
+                    features.append(None)
+            else:
+                features.append(None)
+                features.append(None)
             features.append(x)

         return features
@@ -243,10 +264,14 @@ class extractor(nn.Module):
 class Adapter_light(nn.Module):
     def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
         super(Adapter_light, self).__init__()
-        self.unshuffle = nn.PixelUnshuffle(8)
+        self.unshuffle_amount = 8
+        self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
+        self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
         self.channels = channels
         self.nums_rb = nums_rb
         self.body = []
+        self.xl = False
+
         for i in range(len(channels)):
             if i == 0:
                 self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
diff --git a/comfy/utils.py b/comfy/utils.py
index 0fc73a98e..2cb773cee 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -34,6 +34,13 @@ def save_torch_file(sd, ckpt, metadata=None):
     else:
         safetensors.torch.save_file(sd, ckpt)

+def calculate_parameters(sd, prefix=""):
+    params = 0
+    for k in sd.keys():
+        if k.startswith(prefix):
+            params += sd[k].nelement()
+    return params
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -232,6 +239,20 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
             return None
         return f.read(length_of_header)

+def set_attr(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    prev = getattr(obj, attrs[-1])
+    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    del prev
+
+def get_attr(obj, attr):
+    attrs = attr.split(".")
+    for name in attrs:
+        obj = getattr(obj, name)
+    return obj
+
 def bislerp(samples, width, height):
     def slerp(b1, b2, r):
         '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
diff --git a/web/extensions/core/groupOptions.js b/web/extensions/core/groupOptions.js
new file mode 100644
index 000000000..1d935e90a
--- /dev/null
+++ b/web/extensions/core/groupOptions.js
@@ -0,0 +1,167 @@
+import {app} from "../../scripts/app.js";
+
+function setNodeMode(node, mode) {
+    node.mode = mode;
+    node.graph.change();
+}
+
+app.registerExtension({
+    name: "Comfy.GroupOptions",
+    setup() {
+        const orig = LGraphCanvas.prototype.getCanvasMenuOptions;
+        // graph_mouse
+        LGraphCanvas.prototype.getCanvasMenuOptions = function () {
+            const options = orig.apply(this, arguments);
+            const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
+            if (!group) {
+                return options;
+            }
+
+            // Group nodes aren't recomputed until the group is moved, this ensures the nodes are up-to-date
+            group.recomputeInsideNodes();
+            const nodesInGroup = group._nodes;
+
+            // No nodes in group, return default options
+            if (nodesInGroup.length === 0) {
+                return options;
+            } else {
+                // Add a separator between the default options and the group options
+                options.push(null);
+            }
+
+            // Check if all nodes are the same mode
+            let allNodesAreSameMode = true;
+            for (let i = 1; i < nodesInGroup.length; i++) {
+                if (nodesInGroup[i].mode !== nodesInGroup[0].mode) {
+                    allNodesAreSameMode = false;
+                    break;
+                }
+            }
+
+            // Modes
+            // 0: Always
+            // 1: On Event
+            // 2: Never
+            // 3: On Trigger
+            // 4: Bypass
+            // If all nodes are the same mode, add a menu option to change the mode
+            if (allNodesAreSameMode) {
+                const mode = nodesInGroup[0].mode;
+                switch (mode) {
+                    case 0:
+                        // All nodes are always, option to disable, and bypass
+                        options.push({
+                            content: "Set Group Nodes to Never",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 2);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Bypass Group Nodes",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 4);
+                                }
+                            }
+                        });
+                        break;
+                    case 2:
+                        // All nodes are never, option to enable, and bypass
+                        options.push({
+                            content: "Set Group Nodes to Always",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 0);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Bypass Group Nodes",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 4);
+                                }
+                            }
+                        });
+                        break;
+                    case 4:
+                        // All nodes are bypass, option to enable, and disable
+                        options.push({
+                            content: "Set Group Nodes to Always",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 0);
+                                }
+                            }
+                        });
+                        options.push({
+                            content: "Set Group Nodes to Never",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 2);
+                                }
+                            }
+                        });
+                        break;
+                    default:
+                        // All nodes are On Trigger or On Event(Or other?), option to disable, set to always, or bypass
+                        options.push({
+                            content: "Set Group Nodes to Always",
+                            callback: () => {
+                                for (const node of nodesInGroup) {
+                                    setNodeMode(node, 0);
+                                }
+                            }
+                        });
+                        options.push({
content: "Set Group Nodes to Never", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 2); + } + } + }); + options.push({ + content: "Bypass Group Nodes", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 4); + } + } + }); + break; + } + } else { + // Nodes are not all the same mode, add a menu option to change the mode to always, never, or bypass + options.push({ + content: "Set Group Nodes to Always", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 0); + } + } + }); + options.push({ + content: "Set Group Nodes to Never", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 2); + } + } + }); + options.push({ + content: "Bypass Group Nodes", + callback: () => { + for (const node of nodesInGroup) { + setNodeMode(node, 4); + } + } + }); + } + + return options + } + } +}); diff --git a/web/extensions/core/uploadImage.js b/web/extensions/core/uploadImage.js index f50473ae3..530c4599e 100644 --- a/web/extensions/core/uploadImage.js +++ b/web/extensions/core/uploadImage.js @@ -5,7 +5,7 @@ import { app } from "../../scripts/app.js"; app.registerExtension({ name: "Comfy.UploadImage", async beforeRegisterNodeDef(nodeType, nodeData, app) { - if (nodeData.name === "LoadImage" || nodeData.name === "LoadImageMask") { + if (nodeData?.input?.required?.image?.[1]?.image_upload === true) { nodeData.input.required.upload = ["IMAGEUPLOAD"]; } }, diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index adf5f26fa..5a4644b13 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -299,11 +299,17 @@ export const ComfyWidgets = { const defaultVal = inputData[1].default || ""; const multiline = !!inputData[1].multiline; + let res; if (multiline) { - return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app); + res = addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app); } else { - return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) }; + res = { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) }; } + + if(inputData[1].dynamicPrompts != undefined) + res.widget.dynamicPrompts = inputData[1].dynamicPrompts; + + return res; }, COMBO(node, inputName, inputData) { const type = inputData[0];