mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-14 15:32:35 +08:00
Draft PR for SSD loading
This commit is contained in:
parent
723847f6b3
commit
ab987763cb
@ -304,7 +304,8 @@ class UNetModel(nn.Module):
|
|||||||
resblock_updown=False,
|
resblock_updown=False,
|
||||||
use_new_attention_order=False,
|
use_new_attention_order=False,
|
||||||
use_spatial_transformer=False, # custom transformer support
|
use_spatial_transformer=False, # custom transformer support
|
||||||
transformer_depth=1, # custom transformer support
|
transformer_depth=1, # custom transformer support
|
||||||
|
upsampling_depth=None,
|
||||||
context_dim=None, # custom transformer support
|
context_dim=None, # custom transformer support
|
||||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||||
legacy=True,
|
legacy=True,
|
||||||
@ -443,8 +444,10 @@ class UNetModel(nn.Module):
|
|||||||
disabled_sa = False
|
disabled_sa = False
|
||||||
|
|
||||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||||
|
if isinstance(transformer_depth[level],int):
|
||||||
|
transformer_depth[level]=transformer_depth[level]*num_attention_blocks[level]
|
||||||
layers.append(SpatialTransformer(
|
layers.append(SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
@ -518,7 +521,7 @@ class UNetModel(nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
|
||||||
self.output_blocks = nn.ModuleList([])
|
self.output_blocks = nn.ModuleList([])
|
||||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||||
for i in range(self.num_res_blocks[level] + 1):
|
for i in range(self.num_res_blocks[level] + 1):
|
||||||
@ -553,9 +556,11 @@ class UNetModel(nn.Module):
|
|||||||
disabled_sa = False
|
disabled_sa = False
|
||||||
|
|
||||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||||
|
if isinstance(transformer_depth[level],int):
|
||||||
|
transformer_depth[level]=transformer_depth[level]*num_attention_blocks[level]
|
||||||
layers.append(
|
layers.append(
|
||||||
SpatialTransformer(
|
SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth[level][i], context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user