mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 21:42:37 +08:00
Add optional resnet-only midblock
This commit is contained in:
parent
ab987763cb
commit
a58a848e24
@ -306,6 +306,7 @@ class UNetModel(nn.Module):
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
upsampling_depth=None,
|
||||
resnet_only_mid_block=False,
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
@ -447,7 +448,7 @@ class UNetModel(nn.Module):
|
||||
if isinstance(transformer_depth[level],int):
|
||||
transformer_depth[level]=transformer_depth[level]*num_attention_blocks[level]
|
||||
layers.append(SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level][nr], context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
@ -503,22 +504,23 @@ class UNetModel(nn.Module):
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
if (resnet_only_mid_block is False):
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
|
||||
|
||||
Loading…
Reference in New Issue
Block a user