From af365e4dd152b23cd6cf993ddf9ed7c7e7088b39 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 03:12:18 -0500 Subject: [PATCH 1/6] All the unet ops with weights are now handled by comfy.ops --- comfy/controlnet.py | 10 ++++++++++ comfy/ldm/modules/attention.py | 18 ++++-------------- .../modules/diffusionmodules/openaimodel.py | 13 ++++++------- comfy/ops.py | 8 ++++++++ 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 433381df6..6dd99afdc 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -5,6 +5,7 @@ import comfy.utils import comfy.model_management import comfy.model_detection import comfy.model_patcher +import comfy.ops import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -248,6 +249,15 @@ class ControlLoraOps: else: raise ValueError(f"unsupported dimensions: {dims}") + class Conv3d(comfy.ops.Conv3d): + pass + + class GroupNorm(comfy.ops.GroupNorm): + pass + + class LayerNorm(comfy.ops.LayerNorm): + pass + class ControlLora(ControlNet): def __init__(self, control_weights, global_average_pooling=False, device=None): diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index f68452382..c2b85a691 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -83,16 +83,6 @@ class FeedForward(nn.Module): def forward(self, x): return self.net(x) - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) @@ -414,10 +404,10 @@ class BasicTransformerBlock(nn.Module): self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head @@ -559,7 +549,7 @@ class SpatialTransformer(nn.Module): context_dim = [context_dim] * depth self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) if not use_linear: self.proj_in = operations.Conv2d(in_channels, inner_dim, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 48264892c..855c3d1f4 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -177,7 +177,7 @@ class ResBlock(TimestepBlock): padding = kernel_size // 2 self.in_layers = nn.Sequential( - nn.GroupNorm(32, channels, dtype=dtype, device=device), + operations.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), ) @@ -206,12 +206,11 @@ class ResBlock(TimestepBlock): ), ) self.out_layers = nn.Sequential( - nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), + operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) - ), + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) + , ) if self.out_channels == channels: @@ -810,13 +809,13 @@ class UNetModel(nn.Module): self._feature_size += ch self.out = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.SiLU(), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) diff --git a/comfy/ops.py b/comfy/ops.py index 0bfb698aa..deb849d63 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -13,6 +13,14 @@ class Conv3d(torch.nn.Conv3d): def reset_parameters(self): return None +class GroupNorm(torch.nn.GroupNorm): + def reset_parameters(self): + return None + +class LayerNorm(torch.nn.LayerNorm): + def reset_parameters(self): + return None + def conv_nd(dims, *args, **kwargs): if dims == 2: return Conv2d(*args, **kwargs) From 31b0f6f3d8034371e95024d6bba5c193db79bd9d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 11:10:00 -0500 Subject: [PATCH 2/6] UNET weights can now be stored in fp8. --fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats supported by pytorch. --- comfy/cldm/cldm.py | 4 ++-- comfy/cli_args.py | 5 ++++- comfy/controlnet.py | 16 ++++++++++++---- .../ldm/modules/diffusionmodules/openaimodel.py | 4 ++-- comfy/model_base.py | 13 ++++++++++++- comfy/model_management.py | 15 +++++++++++++++ 6 files changed, 47 insertions(+), 10 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 76a525b37..bbe5891e6 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -283,7 +283,7 @@ class ControlNet(nn.Module): return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) def forward(self, x, hint, timesteps, context, y=None, **kwargs): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) guided_hint = self.input_hint_block(hint, emb, context) @@ -295,7 +295,7 @@ class ControlNet(nn.Module): assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for module, zero_conv in zip(self.input_blocks, self.zero_convs): if guided_hint is not None: h = module(h, emb, context) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 72fce1087..58d034802 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -55,7 +55,10 @@ fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") -parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group = parser.add_mutually_exclusive_group() +fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") +fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6dd99afdc..5921e6b1d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,6 +1,7 @@ import torch import math import os +import contextlib import comfy.utils import comfy.model_management import comfy.model_detection @@ -147,24 +148,31 @@ class ControlNet(ControlBase): else: return None + dtype = self.control_model.dtype + if comfy.model_management.supports_dtype(self.device, dtype): + precision_scope = lambda a: contextlib.nullcontext(a) + else: + precision_scope = torch.autocast + dtype = torch.float32 + output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] y = cond.get('y', None) if y is not None: - y = y.to(self.control_model.dtype) + y = y.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) + with precision_scope(comfy.model_management.get_autocast_device(self.device)): + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 855c3d1f4..12efd833c 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -841,14 +841,14 @@ class UNetModel(nn.Module): self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) diff --git a/comfy/model_base.py b/comfy/model_base.py index 253ea6667..5bfcc391d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -5,6 +5,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import comfy.model_management import comfy.conds from enum import Enum +import contextlib from . import utils class ModelType(Enum): @@ -61,6 +62,13 @@ class BaseModel(torch.nn.Module): context = c_crossattn dtype = self.get_dtype() + + if comfy.model_management.supports_dtype(xc.device, dtype): + precision_scope = lambda a: contextlib.nullcontext(a) + else: + precision_scope = torch.autocast + dtype = torch.float32 + xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() context = context.to(dtype) @@ -70,7 +78,10 @@ class BaseModel(torch.nn.Module): if hasattr(extra, "to"): extra = extra.to(dtype) extra_conds[o] = extra - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + + with precision_scope(comfy.model_management.get_autocast_device(xc.device)): + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): diff --git a/comfy/model_management.py b/comfy/model_management.py index d4acd8950..18d15f9d0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -459,6 +459,10 @@ def unet_inital_load_device(parameters, dtype): def unet_dtype(device=None, model_params=0): if args.bf16_unet: return torch.bfloat16 + if args.fp8_e4m3fn_unet: + return torch.float8_e4m3fn + if args.fp8_e5m2_unet: + return torch.float8_e5m2 if should_use_fp16(device=device, model_params=model_params): return torch.float16 return torch.float32 @@ -515,6 +519,17 @@ def get_autocast_device(dev): return dev.type return "cuda" +def supports_dtype(device, dtype): #TODO + if dtype == torch.float32: + return True + if torch.device("cpu") == device: + return False + if dtype == torch.float16: + return True + if dtype == torch.bfloat16: + return True + return False + def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: From ca82ade7652c80727b402f51a115feb5df4ad27a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 11:52:06 -0500 Subject: [PATCH 3/6] Use .itemsize to get dtype size for fp8. --- comfy/model_management.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 18d15f9d0..94d596969 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -430,6 +430,13 @@ def dtype_size(dtype): dtype_size = 4 if dtype == torch.float16 or dtype == torch.bfloat16: dtype_size = 2 + elif dtype == torch.float32: + dtype_size = 4 + else: + try: + dtype_size = dtype.itemsize + except: #Old pytorch doesn't have .itemsize + pass return dtype_size def unet_offload_device(): From be3468ddd5db871e3943003e0fd7a2219c7d02e6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 12:49:00 -0500 Subject: [PATCH 4/6] Less useless downcasting. --- comfy/sd1_clip.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 58acb97fc..4e9f6bffe 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.inner_name = inner_name if dtype is not None: - self.transformer.to(dtype) inner_model = getattr(self.transformer, self.inner_name) if hasattr(inner_model, "embeddings"): - inner_model.embeddings.to(torch.float32) + embeddings_bak = inner_model.embeddings.to(torch.float32) + inner_model.embeddings = None + self.transformer.to(dtype) + inner_model.embeddings = embeddings_bak else: - self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) + previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True) + self.transformer.to(dtype) + self.transformer.set_input_embeddings(previous_inputs) self.max_length = max_length if freeze: From 26b1c0a77150be2253f88e4cd106a11112d96d59 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 13:47:41 -0500 Subject: [PATCH 5/6] Fix control lora on fp8. --- comfy/controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 5921e6b1d..6d37aa74f 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -208,7 +208,7 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) else: return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) @@ -247,7 +247,7 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) else: return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) From 9b655d4fd72903d33af101177b0cb9576c5babd3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 21:55:19 -0500 Subject: [PATCH 6/6] Fix memory issue with control loras. --- comfy/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 034db97ee..bcbed3343 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -101,7 +101,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative samples = samples.cpu() cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): @@ -113,6 +113,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.cpu() cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples