mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-14 07:22:36 +08:00
Fixed bug in openaimodel.py
This commit is contained in:
parent
ed9ac9205a
commit
25a9a29ab4
@ -492,36 +492,51 @@ class UNetModel(nn.Module):
|
|||||||
if legacy:
|
if legacy:
|
||||||
#num_heads = 1
|
#num_heads = 1
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
self.middle_block = TimestepEmbedSequential(
|
if resnet_only_mid_block is False:
|
||||||
ResBlock(
|
self.middle_block = TimestepEmbedSequential(
|
||||||
ch,
|
ResBlock(
|
||||||
time_embed_dim,
|
ch,
|
||||||
dropout,
|
time_embed_dim,
|
||||||
dims=dims,
|
dropout,
|
||||||
use_checkpoint=use_checkpoint,
|
dims=dims,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_checkpoint=use_checkpoint,
|
||||||
dtype=self.dtype,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
device=device,
|
dtype=self.dtype,
|
||||||
operations=operations
|
device=device,
|
||||||
),
|
operations=operations
|
||||||
if (resnet_only_mid_block is False):
|
),
|
||||||
SpatialTransformer( # always uses a self-attn
|
SpatialTransformer( # always uses a self-attn
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
),
|
),
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
time_embed_dim,
|
time_embed_dim,
|
||||||
dropout,
|
dropout,
|
||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.middle_block = TimestepEmbedSequential(
|
||||||
|
ResBlock(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
dropout,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
|
transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
|
||||||
self.output_blocks = nn.ModuleList([])
|
self.output_blocks = nn.ModuleList([])
|
||||||
|
|||||||
@ -193,7 +193,7 @@ class SSD1B(SDXL):
|
|||||||
unet_config = {
|
unet_config = {
|
||||||
"model_channels": 320,
|
"model_channels": 320,
|
||||||
"use_linear_in_transformer": True,
|
"use_linear_in_transformer": True,
|
||||||
"transformer_depth": [0, 2, 4], # SDXL is [0, 2, 10] here
|
"transformer_depth": [0, 2, 4],
|
||||||
"upsampling_depth": [0,[2,1,1],[4,4,10]],
|
"upsampling_depth": [0,[2,1,1],[4,4,10]],
|
||||||
"resnet_only_mid_block": True,
|
"resnet_only_mid_block": True,
|
||||||
"context_dim": 2048,
|
"context_dim": 2048,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user