Add optional resnet-only midblock

Vishnu V Jaddipal 2023-10-27 12:55:16 +00:00
parent ab987763cb
commit a58a848e24


@@ -306,6 +306,7 @@ class UNetModel(nn.Module):
         use_spatial_transformer=False, # custom transformer support
         transformer_depth=1, # custom transformer support
         upsampling_depth=None,
+        resnet_only_mid_block=False,
         context_dim=None, # custom transformer support
         n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
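The new argument is a defaulted keyword, so existing call sites are unaffected and only callers that want a resnet-only midblock pass it explicitly. A toy sketch of that pattern, using a made-up class rather than the real UNetModel signature:

    # Toy class showing the defaulted-keyword pattern; not the real UNetModel.
    class ToyUNet:
        def __init__(self, model_channels, resnet_only_mid_block=False):
            self.model_channels = model_channels
            self.resnet_only_mid_block = resnet_only_mid_block

    old_caller = ToyUNet(320)                              # existing callers keep working
    new_caller = ToyUNet(320, resnet_only_mid_block=True)  # opts into the resnet-only midblock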
@@ -447,7 +448,7 @@ class UNetModel(nn.Module):
                     if isinstance(transformer_depth[level],int):
                         transformer_depth[level]=transformer_depth[level]*num_attention_blocks[level]
                     layers.append(SpatialTransformer(
-                        ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
+                        ch, num_heads, dim_head, depth=transformer_depth[level][nr], context_dim=context_dim,
                         disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                         use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                     )
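For context on the one-character change above: each level's transformer_depth entry is expanded to one depth per attention block and must then be indexed by the inner block counter, which is nr in this loop, not i. A minimal sketch of that indexing pattern with made-up values (the [depth] * n expansion here is purely for illustration, not this file's exact expansion line):

    # Illustration only: hypothetical per-level depths and block counts.
    transformer_depth = [1, 2, 10]       # one entry per resolution level
    num_attention_blocks = [2, 2, 2]     # attention blocks per level

    # Expand any scalar entry into a per-block list so it can be indexed per block.
    for level, depth in enumerate(transformer_depth):
        if isinstance(depth, int):
            transformer_depth[level] = [depth] * num_attention_blocks[level]

    for level in range(len(transformer_depth)):
        for nr in range(num_attention_blocks[level]):
            print(level, nr, transformer_depth[level][nr])  # the index this commit fixes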
@@ -503,22 +504,23 @@ class UNetModel(nn.Module):
                 device=device,
                 operations=operations
             ),
-            SpatialTransformer( # always uses a self-attn
-                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype,
-                device=device,
-                operations=operations
-            ),
+            if (resnet_only_mid_block is False):
+                SpatialTransformer( # always uses a self-attn
+                    ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                ),
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
         )
         self._feature_size += ch
         transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
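An if statement cannot appear directly inside the TimestepEmbedSequential(...) call, so the code around this hunk presumably assembles the midblock modules before handing them to the container. A minimal, self-contained sketch of the resnet-only-midblock idea under that assumption, with nn.Identity standing in for ResBlock/SpatialTransformer, nn.Sequential standing in for TimestepEmbedSequential, and the hypothetical helper name build_middle_block:

    import torch.nn as nn

    def build_middle_block(resnet_only_mid_block=False,
                           make_resblock=nn.Identity,
                           make_transformer=nn.Identity):
        # Collect the modules in a list first so the attention block and the
        # trailing ResBlock can be skipped without an `if` inside the call.
        layers = [make_resblock()]              # leading ResBlock, always present
        if not resnet_only_mid_block:
            layers += [make_transformer(),      # midblock SpatialTransformer
                       make_resblock()]         # trailing ResBlock
        return nn.Sequential(*layers)           # stand-in for TimestepEmbedSequential

    print(len(build_middle_block()))                            # 3: resnet, attn, resnet
    print(len(build_middle_block(resnet_only_mid_block=True)))  # 1: resnet only

With the flag enabled, only the leading ResBlock (whose tail appears as context at the top of the hunk) remains in the middle block.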