From 37a86e4618098ef1e0d692d0953f072388cbc673 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 28 Feb 2024 03:57:41 -0500 Subject: [PATCH 1/3] Remove duplicate text_projection key from some saved models. --- comfy/diffusers_convert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 8e3ca94e5..eb561933a 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -240,9 +240,9 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): text_proj = "transformer.text_projection.weight" if k.endswith(text_proj): new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous() - - relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) - new_state_dict[relabelled_key] = v + else: + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v for k_pre, tensors in capture_qkv_weight.items(): if None in tensors: From b3e97fc7141681b1fa6da3ee6701c0f9a31d38f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 28 Feb 2024 11:55:06 -0500 Subject: [PATCH 2/3] Koala 700M and 1B support. Use the UNET Loader node to load the unet file to use them. --- .../modules/diffusionmodules/openaimodel.py | 48 ++++++++++--------- comfy/model_detection.py | 23 +++++++-- comfy/supported_models.py | 22 ++++++++- 3 files changed, 66 insertions(+), 27 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 998afd977..c54770255 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -708,27 +708,30 @@ class UNetModel(nn.Module): device=device, operations=operations )] - if transformer_depth_middle >= 0: - mid_block += [get_attention_layer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint - ), - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_channels=None, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype, - device=device, - operations=operations - )] - self.middle_block = TimestepEmbedSequential(*mid_block) + + self.middle_block = None + if transformer_depth_middle >= -1: + if transformer_depth_middle >= 0: + mid_block += [get_attention_layer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint + ), + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, + operations=operations + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -858,7 +861,8 @@ class UNetModel(nn.Module): h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + if self.middle_block is not None: + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8fca6d8c8..07ee85708 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -151,8 +151,10 @@ def detect_unet_config(state_dict, key_prefix): channel_mult.append(last_channel_mult) if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - else: + elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = -1 + else: + transformer_depth_middle = -2 unet_config["in_channels"] = in_channels unet_config["out_channels"] = out_channels @@ -242,6 +244,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): down_blocks = count_blocks(state_dict, "down_blocks.{}") for i in range(down_blocks): attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') + res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}') for ab in range(attn_blocks): transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_depth.append(transformer_count) @@ -250,8 +253,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): attn_res *= 2 if attn_blocks == 0: - transformer_depth.append(0) - transformer_depth.append(0) + for i in range(res_blocks): + transformer_depth.append(0) match["transformer_depth"] = transformer_depth @@ -329,7 +332,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega] + KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B] for unet_config in supported_models: matches = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 74908216c..375821032 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -234,6 +234,26 @@ class Segmind_Vega(SDXL): "use_temporal_attention": False, } +class KOALA_700M(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 5], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + +class KOALA_1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 6], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + class SVD_img2vid(supported_models_base.BASE): unet_config = { "model_channels": 320, @@ -380,5 +400,5 @@ class Stable_Cascade_B(Stable_Cascade_C): return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] models += [SVD_img2vid] From cb7c3a2921cfc0805be0229b4634e1143d60e6fe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 29 Feb 2024 13:09:43 -0500 Subject: [PATCH 3/3] Allow image_only_indicator to be None. --- .../modules/diffusionmodules/openaimodel.py | 3 +-- comfy/ldm/modules/diffusionmodules/util.py | 22 ++++++++++--------- comfy/model_base.py | 1 - 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index c54770255..cf89ae017 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -484,7 +484,6 @@ class UNetModel(nn.Module): self.predict_codebook_ids = n_embed is not None self.default_num_video_frames = None - self.default_image_only_indicator = None time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( @@ -830,7 +829,7 @@ class UNetModel(nn.Module): transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) - image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + image_only_indicator = kwargs.get("image_only_indicator", None) time_context = kwargs.get("time_context", None) assert (y is not None) == ( diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 5a6aa7d77..ce14ad5e1 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -46,23 +46,25 @@ class AlphaBlender(nn.Module): else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") - def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor: # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor.to(image_only_indicator.device) + alpha = self.mix_factor.to(device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device)) + alpha = torch.sigmoid(self.mix_factor.to(device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": - assert image_only_indicator is not None, "need image_only_indicator ..." - alpha = torch.where( - image_only_indicator.bool(), - torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), - ) + if image_only_indicator is None: + alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1") + else: + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), + ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) @@ -76,7 +78,7 @@ class AlphaBlender(nn.Module): x_temporal, image_only_indicator=None, ) -> torch.Tensor: - alpha = self.get_alpha(image_only_indicator) + alpha = self.get_alpha(image_only_indicator, x_spatial.device) x = ( alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal diff --git a/comfy/model_base.py b/comfy/model_base.py index 170b1fd44..69117dd0e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -372,7 +372,6 @@ class SVD_img2vid(BaseModel): if "time_conditioning" in kwargs: out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) - out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) return out