From 25a9a29ab4df2b49238373e2eade739658adbda7 Mon Sep 17 00:00:00 2001
From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com>
Date: Fri, 27 Oct 2023 13:21:56 +0000
Subject: [PATCH] Fixed bug in openaimodel.py

---
 .../modules/diffusionmodules/openaimodel.py | 75 +++++++++++--------
 comfy/supported_models.py                   |  2 +-
 2 files changed, 46 insertions(+), 31 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 519651e03..284a34fed 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -492,36 +492,51 @@ class UNetModel(nn.Module):
         if legacy:
             #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype,
-                device=device,
-                operations=operations
-            ),
-            if (resnet_only_mid_block is False):
-                SpatialTransformer( # always uses a self-attn
-                    ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                    disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-                ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype,
-                device=device,
-                operations=operations
-            ),
-        )
+        if resnet_only_mid_block is False:
+            self.middle_block = TimestepEmbedSequential(
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
+                SpatialTransformer( # always uses a self-attn
+                    ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                ),
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
+            )
+        else:
+            self.middle_block = TimestepEmbedSequential(
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
+            )
+
         self._feature_size += ch
         transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
         self.output_blocks = nn.ModuleList([])
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 2e1497025..fa13ac999 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -193,7 +193,7 @@ class SSD1B(SDXL):
     unet_config = {
         "model_channels": 320,
         "use_linear_in_transformer": True,
-        "transformer_depth": [0, 2, 4], # SDXL is [0, 2, 10] here
+        "transformer_depth": [0, 2, 4],
         "upsampling_depth": [0,[2,1,1],[4,4,10]],
         "resnet_only_mid_block": True,
         "context_dim": 2048,