Draft PR for SSD loading

Vishnu V Jaddipal 2023-10-26 14:06:19 +00:00
parent 723847f6b3
commit ab987763cb


@@ -304,7 +304,8 @@ class UNetModel(nn.Module):
         resblock_updown=False,
         use_new_attention_order=False,
         use_spatial_transformer=False,    # custom transformer support
         transformer_depth=1,              # custom transformer support
+        upsampling_depth=None,
         context_dim=None,                 # custom transformer support
         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
@@ -443,8 +444,10 @@ class UNetModel(nn.Module):
                             disabled_sa = False
                         if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                            if isinstance(transformer_depth[level], int):
+                                transformer_depth[level] = [transformer_depth[level]] * num_attention_blocks[level]
                             layers.append(SpatialTransformer(
-                                    ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                    ch, num_heads, dim_head, depth=transformer_depth[level][nr], context_dim=context_dim,
                                     disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                     use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                                 )
@@ -518,7 +521,7 @@ class UNetModel(nn.Module):
             ),
         )
         self._feature_size += ch
+        transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
         self.output_blocks = nn.ModuleList([])
         for level, mult in list(enumerate(channel_mult))[::-1]:
             for i in range(self.num_res_blocks[level] + 1):
@@ -553,9 +556,11 @@ class UNetModel(nn.Module):
                             disabled_sa = False
                         if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                            if isinstance(transformer_depth[level], int):
+                                transformer_depth[level] = [transformer_depth[level]] * num_attention_blocks[level]
                             layers.append(
                                 SpatialTransformer(
-                                    ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                    ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
                                     disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                     use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                                 )
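
Taken together, the changes let each resolution level carry one transformer depth per attention block (a scalar entry is expanded into a list), while upsampling_depth, when given, overrides transformer_depth for the decoder (output) blocks. Below is a minimal standalone sketch of that expansion step; the config values are illustrative assumptions, not taken from this commit.

# Sketch of the per-level depth expansion introduced above; the values
# here are hypothetical, not from this PR.
transformer_depth = [1, 2, 10]      # one entry per resolution level
num_attention_blocks = [2, 2, 2]    # attention blocks per level

for level in range(len(transformer_depth)):
    if isinstance(transformer_depth[level], int):
        # Expand a scalar depth into one depth per attention block, so the
        # block loop can index it as transformer_depth[level][i].
        transformer_depth[level] = [transformer_depth[level]] * num_attention_blocks[level]

print(transformer_depth)    # [[1, 1], [2, 2], [10, 10]]

Passing a separate upsampling_depth is what lets the encoder and decoder use different per-block depths, the asymmetry that SSD-style pruned UNets rely on.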