diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index bf58a4045..657230978 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -304,7 +304,8 @@ class UNetModel(nn.Module):
         resblock_updown=False,
         use_new_attention_order=False,
         use_spatial_transformer=False,    # custom transformer support
-        transformer_depth=1,              # custom transformer support
+        transformer_depth=1,              # custom transformer support
+        upsampling_depth=None,            # optional per-level transformer depth override for the output blocks
         context_dim=None,                 # custom transformer support
         n_embed=None,    # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
@@ -443,8 +444,10 @@ class UNetModel(nn.Module):
                         disabled_sa = False
 
                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        if isinstance(transformer_depth[level], int):
+                            transformer_depth[level] = [transformer_depth[level]] * (self.num_res_blocks[level] + 1)  # expand scalar depth to one entry per block; +1 also covers the extra output block per level
                         layers.append(SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=transformer_depth[level][nr], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
@@ -518,7 +521,7 @@ class UNetModel(nn.Module):
                     ),
                 )
                 self._feature_size += ch
-
+        transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth  # switch to the decoder depths if an override was provided
         self.output_blocks = nn.ModuleList([])
         for level, mult in list(enumerate(channel_mult))[::-1]:
             for i in range(self.num_res_blocks[level] + 1):
@@ -553,9 +556,11 @@ class UNetModel(nn.Module):
                         disabled_sa = False
 
                     if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        if isinstance(transformer_depth[level], int):
+                            transformer_depth[level] = [transformer_depth[level]] * (self.num_res_blocks[level] + 1)  # expand scalar depth to one entry per block
                         layers.append(
                             SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
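
For reference, a minimal standalone sketch of the depth-expansion rule this patch applies (plain Python, independent of comfy; `num_res_blocks` and all depth values below are made-up examples): a scalar per-level `transformer_depth` entry is broadcast to one depth per block, while an explicit per-block list is kept as-is, and `upsampling_depth`, when given, replaces the encoder-side depths before the output blocks are built.

```python
# Standalone illustration of the patch's depth handling (hypothetical values).
num_res_blocks = [2, 2, 2]                 # example: blocks per resolution level
transformer_depth = [1, 2, [4, 4, 10]]     # per-level depth: scalar or explicit per-block list
upsampling_depth = None                    # optional decoder-side override, as in the patch

def expand(depths):
    # Broadcast a scalar level depth to one entry per block; the +1 mirrors
    # the patched __init__, whose output-block loop runs num_res_blocks + 1 times.
    return [d if isinstance(d, list) else [d] * (num_res_blocks[lvl] + 1)
            for lvl, d in enumerate(depths)]

encoder_depths = expand(transformer_depth)
decoder_depths = expand(upsampling_depth) if upsampling_depth is not None else encoder_depths

print(encoder_depths)        # [[1, 1, 1], [2, 2, 2], [4, 4, 10]]
print(decoder_depths[2][1])  # depth for block 1 at level 2 -> 4
```

The `+1` in the expansion matches the output-block loop `for i in range(self.num_res_blocks[level] + 1)`, so the expanded lists stay index-safe when `upsampling_depth` is omitted and the decoder reuses the encoder depths.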