From a82fae2375744a73d63268bc4e167649e3f026e0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 10 Jun 2024 16:00:03 -0400 Subject: [PATCH 01/15] Fix bug with cosxl edit model. --- comfy/conds.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/conds.py b/comfy/conds.py index 23fa48872..660690af8 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -29,7 +29,12 @@ class CONDRegular: class CONDNoiseShape(CONDRegular): def process_cond(self, batch_size, device, area, **kwargs): - data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + data = self.cond + if area is not None: + dims = len(area) // 2 + for i in range(dims): + data = data.narrow(i + 2, area[i + dims], area[i]) + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) From 4134564dc15c5eb40a61e2da4c493cf6786e67d1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 11 Jun 2024 06:26:13 -0400 Subject: [PATCH 02/15] Require safetensors library to be at least 0.4.2 for fp8 support. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 906b96eda..8f681f8fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ torchsde torchvision einops transformers>=4.25.1 -safetensors>=0.3.0 +safetensors>=0.4.2 aiohttp pyyaml Pillow From 73ce178021338e2cea419a00ae61ec0a6630ef19 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Tue, 11 Jun 2024 18:30:25 +0800 Subject: [PATCH 03/15] Remove redundancy in mmdit.py (#3685) --- comfy/ldm/modules/diffusionmodules/mmdit.py | 61 --------------------- 1 file changed, 61 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 5e7afc8d3..be40ab940 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -835,72 +835,11 @@ class MMDiT(nn.Module): ) self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) - # self.initialize_weights() if compile_core: assert False self.forward_core_with_concat = torch.compile(self.forward_core_with_concat) - def initialize_weights(self): - # TODO: Init context_embedder? - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) - - # Initialize (and freeze) pos_embed by sin-cos embedding - if self.pos_embed is not None: - pos_embed_grid_size = ( - int(self.x_embedder.num_patches**0.5) - if self.pos_embed_max_size is None - else self.pos_embed_max_size - ) - pos_embed = get_2d_sincos_pos_embed( - self.pos_embed.shape[-1], - int(self.x_embedder.num_patches**0.5), - pos_embed_grid_size, - scaling_factor=self.pos_embed_scaling_factor, - offset=self.pos_embed_offset, - ) - - - pos_embed = get_2d_sincos_pos_embed( - self.pos_embed.shape[-1], - int(self.pos_embed.shape[-2]**0.5), - scaling_factor=self.pos_embed_scaling_factor, - ) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - if hasattr(self, "y_embedder"): - nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # Zero-out adaLN modulation layers in DiT blocks: - for block in self.joint_blocks: - nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) - nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - def cropped_pos_embed(self, hw, device=None): p = self.x_embedder.patch_size[0] h, w = hw From 9424522ead16a36199c50217ae09093ff8e3223d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 11 Jun 2024 07:20:26 -0400 Subject: [PATCH 04/15] Reuse code. --- comfy/model_base.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index a26b442b1..28458bbab 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -573,13 +573,8 @@ class SD3(BaseModel): return kwargs["pooled_output"] def extra_conds(self, **kwargs): - out = {} - adm = self.encode_adm(**kwargs) - if adm is not None: - out['y'] = comfy.conds.CONDRegular(adm) - + out = super().extra_conds(**kwargs) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - From 1c34d338d7cc11513023080cd0adedbf9f997356 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 11 Jun 2024 07:37:22 -0400 Subject: [PATCH 05/15] Update EmptySD3LatentImage to use 1024 resolution by default. --- comfy_extras/nodes_sd3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 80b8644a4..d0303aec5 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -27,8 +27,8 @@ class EmptySD3LatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" From 5889b7ca0ad7b3bf036999330b0c5371ce0da0f3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 11 Jun 2024 13:14:43 -0400 Subject: [PATCH 06/15] Support multiple text encoder configurations on SD3. --- comfy/sd.py | 2 +- comfy/sd3_clip.py | 89 ++++++++++++++++++++++++++++----------- comfy/supported_models.py | 35 ++++++++++----- 3 files changed, 91 insertions(+), 35 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index cb147fa46..11764077a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -482,7 +482,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o vae = VAE(sd=vae_sd) if output_clip: - clip_target = model_config.clip_target() + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py index bbbf6affd..595381fc0 100644 --- a/comfy/sd3_clip.py +++ b/comfy/sd3_clip.py @@ -5,6 +5,7 @@ import comfy.t5 import torch import os import comfy.model_management +import logging class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): @@ -43,42 +44,82 @@ class SD3Tokenizer: return self.clip_g.untokenize(token_weight_pair) class SD3ClipModel(torch.nn.Module): - def __init__(self, device="cpu", dtype=None): + def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) - self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) - self.t5xxl = T5XXLModel(device=device, dtype=dtype) + if clip_l: + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) + else: + self.clip_l = None + + if clip_g: + self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) + else: + self.clip_g = None + + if t5: + self.t5xxl = T5XXLModel(device=device, dtype=dtype) + else: + self.t5xxl = None + + logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5)) def set_clip_options(self, options): - self.clip_l.set_clip_options(options) - self.clip_g.set_clip_options(options) - self.t5xxl.set_clip_options(options) + if self.clip_l is not None: + self.clip_l.set_clip_options(options) + if self.clip_g is not None: + self.clip_g.set_clip_options(options) + if self.t5xxl is not None: + self.t5xxl.set_clip_options(options) def reset_clip_options(self): - self.clip_g.reset_clip_options() - self.clip_l.reset_clip_options() - self.t5xxl.reset_clip_options() + if self.clip_l is not None: + self.clip_l.reset_clip_options() + if self.clip_g is not None: + self.clip_g.reset_clip_options() + if self.t5xxl is not None: + self.t5xxl.reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs_l = token_weight_pairs["l"] token_weight_pairs_g = token_weight_pairs["g"] token_weight_pars_t5 = token_weight_pairs["t5xxl"] lg_out = None - if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: - l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) - g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) - lg_out = torch.cat([l_out, g_out], dim=-1) - lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) - out = lg_out - pooled = torch.cat((l_pooled, g_pooled), dim=-1) - else: - pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device()) + pooled = None + out = None - t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5) - if lg_out is not None: - out = torch.cat([lg_out, t5_out], dim=-2) - else: - out = t5_out + if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: + if self.clip_l is not None: + lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + else: + l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device()) + + if self.clip_g is not None: + g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) + if lg_out is not None: + lg_out = torch.cat([lg_out, g_out], dim=-1) + else: + lg_out = torch.nn.functional.pad(g_out, (768, 0)) + else: + g_out = None + g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device()) + + if lg_out is not None: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + out = lg_out + pooled = torch.cat((l_pooled, g_pooled), dim=-1) + + if self.t5xxl is not None: + t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5) + if lg_out is not None: + out = torch.cat([lg_out, t5_out], dim=-2) + else: + out = t5_out + + if out is None: + out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device()) + + if pooled is None: + pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) return out, pooled diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6bb76c96f..481ecaa62 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE): replace_prefix = {"clip_l.": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel) class SD20(supported_models_base.BASE): @@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE): state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel) class SD21UnclipL(SD20): @@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE): state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) class SDXL(supported_models_base.BASE): @@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE): state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) class SSD1B(SDXL): @@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE): out = model_base.SVD_img2vid(self, device=device) return out - def clip_target(self): + def clip_target(self, state_dict={}): return None class SV3D_u(SVD_img2vid): @@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE): out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) return out - def clip_target(self): + def clip_target(self, state_dict={}): return None class SD_X4Upscaler(SD20): @@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE): out = model_base.StableCascade_C(self, device=device) return out - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) class Stable_Cascade_B(Stable_Cascade_C): @@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD3 - text_encoder_key_prefix = ["text_encoders."] #TODO? + text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): out = model_base.SD3(self, device=device) return out - def clip_target(self): - return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO? + def clip_target(self, state_dict={}): + clip_l = False + clip_g = False + t5 = False + pref = self.text_encoder_key_prefix[0] + if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: + clip_l = True + if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: + clip_g = True + if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict: + t5 = True + + class SD3ClipModel(sd3_clip.SD3ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype) + + return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel) models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3] From 0e49211a110fff099ebafa927ee7b9416ff9feaa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 11 Jun 2024 17:03:26 -0400 Subject: [PATCH 07/15] Load the SD3 T5xxl model in the same dtype stored in the checkpoint. --- comfy/model_management.py | 17 +++++++++++++++++ comfy/sd.py | 8 +++++++- comfy/sd1_clip.py | 4 ++++ comfy/sd3_clip.py | 18 +++++++++++++++--- comfy/sdxl_clip.py | 1 + comfy/supported_models.py | 7 +++++-- 6 files changed, 49 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 57aa8bca2..dbd0dbac6 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -639,6 +639,23 @@ def supports_dtype(device, dtype): #TODO return True return False +def supports_cast(device, dtype): #TODO + if dtype == torch.float32: + return True + if dtype == torch.float16: + return True + if is_device_mps(device): + return False + if directml_enabled: #TODO: test this + return False + if dtype == torch.bfloat16: + return True + if dtype == torch.float8_e4m3fn: + return True + if dtype == torch.float8_e5m2: + return True + return False + def device_supports_non_blocking(device): if is_device_mps(device): return False #pytorch bug? mps doesn't support non blocking diff --git a/comfy/sd.py b/comfy/sd.py index 11764077a..a7b4dbcf2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -98,13 +98,19 @@ class CLIP: load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() params['device'] = offload_device - params['dtype'] = model_management.text_encoder_dtype(load_device) + dtype = model_management.text_encoder_dtype(load_device) + params['dtype'] = dtype self.cond_stage_model = clip(**(params)) + for dt in self.cond_stage_model.dtypes: + if not model_management.supports_cast(load_device, dt): + load_device = offload_device + self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.layer_idx = None + logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device)) def clone(self): n = CLIP(no_init=True) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 2729f14d8..911af0a7e 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -511,6 +511,10 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) + self.dtypes = set() + if dtype is not None: + self.dtypes.add(dtype) + def set_clip_options(self, options): getattr(self, self.clip).set_clip_options(options) diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py index 595381fc0..cbbbe53dd 100644 --- a/comfy/sd3_clip.py +++ b/comfy/sd3_clip.py @@ -44,24 +44,36 @@ class SD3Tokenizer: return self.clip_g.untokenize(token_weight_pair) class SD3ClipModel(torch.nn.Module): - def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None): + def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None): super().__init__() + self.dtypes = set() if clip_l: self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) + self.dtypes.add(dtype) else: self.clip_l = None if clip_g: self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) + self.dtypes.add(dtype) else: self.clip_g = None if t5: - self.t5xxl = T5XXLModel(device=device, dtype=dtype) + if dtype_t5 is None: + dtype_t5 = dtype + elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype): + dtype_t5 = dtype + + if not comfy.model_management.supports_cast(device, dtype_t5): + dtype_t5 = dtype + + self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5) + self.dtypes.add(dtype_t5) else: self.t5xxl = None - logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5)) + logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5)) def set_clip_options(self, options): if self.clip_l is not None: diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index e62d1ed86..1257cba1e 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -39,6 +39,7 @@ class SDXLClipModel(torch.nn.Module): super().__init__() self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) + self.dtypes = set([dtype]) def set_clip_options(self, options): self.clip_l.set_clip_options(options) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 481ecaa62..a49df7a35 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -511,17 +511,20 @@ class SD3(supported_models_base.BASE): clip_l = False clip_g = False t5 = False + dtype_t5 = None pref = self.text_encoder_key_prefix[0] if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: clip_l = True if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: clip_g = True - if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict: + t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) + if t5_key in state_dict: t5 = True + dtype_t5 = state_dict[t5_key].dtype class SD3ClipModel(sd3_clip.SD3ClipModel): def __init__(self, device="cpu", dtype=None): - super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype) + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel) From 69c8d6d8a60ac6355d02a8a003d55e304fcc702d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 11 Jun 2024 23:27:39 -0400 Subject: [PATCH 08/15] Single and dual clip loader nodes support SD3. You can use the CLIPLoader to use the t5xxl only or the DualCLIPLoader to use CLIP-L and CLIP-G only for sd3. --- comfy/sd.py | 13 +++++++++++-- comfy/sd3_clip.py | 6 ++++++ comfy/supported_models.py | 6 +----- nodes.py | 17 +++++++++++++---- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index a7b4dbcf2..3fd9e0e98 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -370,6 +370,7 @@ def load_style_model(ckpt_path): class CLIPType(Enum): STABLE_DIFFUSION = 1 STABLE_CASCADE = 2 + SD3 = 3 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] @@ -399,12 +400,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]: + dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype + clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) + clip_target.tokenizer = sd3_clip.SD3Tokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer elif len(clip_data) == 2: - clip_target.clip = sdxl_clip.SDXLClipModel - clip_target.tokenizer = sdxl_clip.SDXLTokenizer + if clip_type == CLIPType.SD3: + clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False) + clip_target.tokenizer = sd3_clip.SD3Tokenizer + else: + clip_target.clip = sdxl_clip.SDXLClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif len(clip_data) == 3: clip_target.clip = sd3_clip.SD3ClipModel clip_target.tokenizer = sd3_clip.SD3Tokenizer diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py index cbbbe53dd..0713eb285 100644 --- a/comfy/sd3_clip.py +++ b/comfy/sd3_clip.py @@ -142,3 +142,9 @@ class SD3ClipModel(torch.nn.Module): return self.clip_l.load_sd(sd) else: return self.t5xxl.load_sd(sd) + +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None): + class SD3ClipModel_(SD3ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) + return SD3ClipModel_ diff --git a/comfy/supported_models.py b/comfy/supported_models.py index a49df7a35..c8ddf3e2c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -522,11 +522,7 @@ class SD3(supported_models_base.BASE): t5 = True dtype_t5 = state_dict[t5_key].dtype - class SD3ClipModel(sd3_clip.SD3ClipModel): - def __init__(self, device="cpu", dtype=None): - super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) - - return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel) + return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5)) models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3] diff --git a/nodes.py b/nodes.py index ef1f85613..6fbeb377e 100644 --- a/nodes.py +++ b/nodes.py @@ -818,7 +818,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), - "type": (["stable_diffusion", "stable_cascade"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -829,6 +829,8 @@ class CLIPLoader: clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION if type == "stable_cascade": clip_type = comfy.sd.CLIPType.STABLE_CASCADE + elif type == "sd3": + clip_type = comfy.sd.CLIPType.SD3 clip_path = folder_paths.get_full_path("clip", clip_name) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) @@ -837,17 +839,24 @@ class CLIPLoader: class DualCLIPLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), + return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), + "clip_name2": (folder_paths.get_filename_list("clip"), ), + "type": (["sdxl", "sd3"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name1, clip_name2): + def load_clip(self, clip_name1, clip_name2, type): clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path2 = folder_paths.get_full_path("clip", clip_name2) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings")) + if type == "sdxl": + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + elif type == "sd3": + clip_type = comfy.sd.CLIPType.SD3 + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) class CLIPVisionLoader: From 694e0b48e0f6d55ddfcbe44a5d2818774241077b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Jun 2024 00:49:00 -0400 Subject: [PATCH 09/15] SD3 better memory usage estimation. --- comfy/model_base.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 28458bbab..7ef034408 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -578,3 +578,15 @@ class SD3(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + + def memory_required(self, input_shape): + if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): + dtype = self.get_dtype() + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + #TODO: this probably needs to be tweaked + area = input_shape[0] * input_shape[2] * input_shape[3] + return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024) + else: + area = input_shape[0] * input_shape[2] * input_shape[3] + return (area * 0.3) * (1024 * 1024) From 32be358213bf14a342073eca6fc45623d07e6267 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Jun 2024 01:02:07 -0400 Subject: [PATCH 10/15] Save SD3 modelspec.architecture in CheckpointSave node. --- comfy_extras/nodes_model_merging.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index bb15112f4..b0d149c60 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -183,6 +183,8 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" elif isinstance(model.model, comfy.model_base.SVD_img2vid): metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1" + elif isinstance(model.model, comfy.model_base.SD3): + metadata["modelspec.architecture"] = "stable-diffusion-v3-medium" #TODO: other SD3 variants else: enable_modelspec = False From 1ddf512fdc69bf8dfb51eb858d3e5ba069570791 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Jun 2024 01:07:58 -0400 Subject: [PATCH 11/15] Don't auto convert clip and vae weights to fp16 when saving checkpoint. --- comfy/model_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7ef034408..21f884ba2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -207,9 +207,6 @@ class BaseModel(torch.nn.Module): unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.get_dtype() == torch.float16: - extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) From c8b5e08dc39171babb5d43f160cc04271591743e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Jun 2024 02:24:39 -0400 Subject: [PATCH 12/15] Default shift value on SD3 is 3.0 --- comfy_extras/nodes_model_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 64002a8db..9bcd3c397 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -136,7 +136,7 @@ class ModelSamplingSD3: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01}), + "shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step":0.01}), }} RETURN_TYPES = ("MODEL",) From 321e509e0a8a143095d29969ef9f1b32c9a93c9f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Jun 2024 09:48:27 -0400 Subject: [PATCH 13/15] Add link to SD3 example page to README. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index de0c062ae..a40dd07dd 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) +- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) and [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) From 0eaa34ec5b22fd305159a78ea92b2ef00105ab18 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Jun 2024 10:32:34 -0400 Subject: [PATCH 14/15] Fix regular empty latent image not working with SD3 and custom sampler. --- comfy_extras/nodes_custom_sampler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 45ef8cf40..69f1b9418 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -380,7 +380,10 @@ class SamplerCustom: def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): latent = latent_image latent_image = latent["samples"] + latent = latent.copy() latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent["samples"] = latent_image + if not add_noise: noise = Noise_EmptyNoise().generate_noise(latent) else: @@ -539,7 +542,9 @@ class SamplerCustomAdvanced: def sample(self, noise, guider, sampler, sigmas, latent_image): latent = latent_image latent_image = latent["samples"] + latent = latent.copy() latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent["samples"] = latent_image noise_mask = None if "noise_mask" in latent: From 605e64f6d3da44235498bf9103d7aab1c95ef211 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Jun 2024 10:39:33 -0400 Subject: [PATCH 15/15] Fix lowvram issue. --- comfy/ldm/modules/diffusionmodules/mmdit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index be40ab940..0cb6bd312 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -934,7 +934,7 @@ class MMDiT(nn.Module): context = self.context_processor(context) hw = x.shape[-2:] - x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype) + x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device) c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D)