Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-11 14:02:37 +08:00)
Draft PR for SSD loading

commit ab987763cb (parent 723847f6b3)
@@ -304,7 +304,8 @@ class UNetModel(nn.Module):
         resblock_updown=False,
         use_new_attention_order=False,
         use_spatial_transformer=False,    # custom transformer support
         transformer_depth=1,              # custom transformer support
+        upsampling_depth=None,
         context_dim=None,                 # custom transformer support
         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
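For context (not part of the commit): this hunk adds an upsampling_depth keyword next to the existing transformer_depth, and a later hunk substitutes it for transformer_depth just before the output blocks are built, so the apparent intent is to let the decoder half of the U-Net run different transformer depths than the encoder, as an asymmetric distilled model such as SSD-1B needs. A hypothetical config fragment; only the kwarg names come from the diff, the numeric values are purely illustrative:

    # Hypothetical illustration -- the values are made up, only the kwarg names come from this diff.
    unet_kwargs = {
        "transformer_depth": [1, 2, 4],   # per-level depths for the input (encoder) blocks
        "upsampling_depth": [1, 1, 2],    # per-level depths for the output (decoder) blocks
    }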
@@ -443,8 +444,10 @@ class UNetModel(nn.Module):
                         disabled_sa = False

                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        if isinstance(transformer_depth[level],int):
+                            transformer_depth[level]=transformer_depth[level]*num_attention_blocks[level]
                         layers.append(SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
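The two added lines expand a scalar per-level depth so that the new depth=transformer_depth[level][i] indexing can pick a depth per attention block. A minimal runnable sketch of that expansion follows; it wraps the scalar in a list so the [level][i] indexing works, which is an assumption here, not a transcription of the committed line (which multiplies the scalar directly):

    # Hypothetical sketch of the per-level -> per-block expansion (not verbatim from the commit).
    transformer_depth = [1, 2, 10]       # illustrative per-level depths
    num_attention_blocks = [2, 2, 2]     # illustrative attention-block counts per level
    for level in range(len(transformer_depth)):
        if isinstance(transformer_depth[level], int):
            # One entry per attention block, so transformer_depth[level][i] selects block i's depth.
            transformer_depth[level] = [transformer_depth[level]] * num_attention_blocks[level]
    print(transformer_depth)             # [[1, 1], [2, 2], [10, 10]]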
@@ -518,7 +521,7 @@ class UNetModel(nn.Module):
             ),
         )
         self._feature_size += ch

+        transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
         self.output_blocks = nn.ModuleList([])
         for level, mult in list(enumerate(channel_mult))[::-1]:
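This one-line change swaps the depth schedule before the decoder is built: if upsampling_depth was supplied it replaces transformer_depth for the output blocks, otherwise the encoder schedule is reused. A minimal runnable sketch of the fallback, with illustrative values:

    # Only the fallback expression itself comes from the diff; the values are illustrative.
    transformer_depth = [1, 2, 4]        # encoder schedule
    upsampling_depth = None              # e.g. [1, 1, 2] for an asymmetric decoder
    transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
    print(transformer_depth)             # [1, 2, 4]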
@@ -553,9 +556,11 @@ class UNetModel(nn.Module):
                         disabled_sa = False

                     if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        if isinstance(transformer_depth[level],int):
+                            transformer_depth[level]=transformer_depth[level]*num_attention_blocks[level]
                         layers.append(
                             SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )