From a58a848e24fcaeb2766dca7a9a0ffa3c5ab6cce7 Mon Sep 17 00:00:00 2001
From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com>
Date: Fri, 27 Oct 2023 12:55:16 +0000
Subject: [PATCH] Add optional resnet-only midblock

---
 .../modules/diffusionmodules/openaimodel.py  | 37 +++++++++++--------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 657230978..519651e03 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -306,6 +306,7 @@ class UNetModel(nn.Module):
         use_spatial_transformer=False,    # custom transformer support
         transformer_depth=1,              # custom transformer support
         upsampling_depth=None,
+        resnet_only_mid_block=False,
         context_dim=None,                 # custom transformer support
         n_embed=None,    # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
@@ -447,7 +448,7 @@ class UNetModel(nn.Module):
                         if isinstance(transformer_depth[level],int):
                             transformer_depth[level]=transformer_depth[level]*num_attention_blocks[level]
                         layers.append(SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=transformer_depth[level][nr], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
@@ -503,22 +504,24 @@ class UNetModel(nn.Module):
                 device=device,
                 operations=operations
             ),
-            SpatialTransformer(  # always uses a self-attn
-                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype,
-                device=device,
-                operations=operations
-            ),
+            *([] if resnet_only_mid_block else [
+                SpatialTransformer(  # always uses a self-attn
+                    ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                ),
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
+            ])
         )
         self._feature_size += ch
         transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
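
Note (outside the patch): because the middle block is assembled with conditional
list unpacking, passing resnet_only_mid_block=True leaves the middle block as a
single ResBlock, so the resulting state dict simply has no middle_block.1.* or
middle_block.2.* keys rather than containing disabled layers. Below is a minimal
sketch of how the new flag might be passed through, assuming the usual LDM-style
UNetModel constructor arguments from openaimodel.py; every concrete value is an
illustrative placeholder, not something specified by this patch:

    # Illustrative sketch only: argument values are placeholders chosen for the
    # example, not taken from the patch.
    from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel

    unet = UNetModel(
        image_size=32,
        in_channels=4,
        model_channels=320,
        out_channels=4,
        num_res_blocks=2,
        attention_resolutions=(4, 2),
        channel_mult=(1, 2, 4),
        use_spatial_transformer=True,   # keep attention in the down/up blocks
        transformer_depth=1,
        context_dim=2048,
        num_head_channels=64,
        resnet_only_mid_block=True,     # new flag: middle block keeps only its leading ResBlock
    )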