Fixed bug in openaimodel.py

Vishnu V Jaddipal 2023-10-27 13:21:56 +00:00
parent ed9ac9205a
commit 25a9a29ab4
2 changed files with 46 additions and 31 deletions


@@ -492,36 +492,51 @@ class UNetModel(nn.Module):
         if legacy:
             #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype,
-                device=device,
-                operations=operations
-            ),
-            if (resnet_only_mid_block is False):
-            SpatialTransformer(  # always uses a self-attn
-                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype,
-                device=device,
-                operations=operations
-            ),
-        )
+        if resnet_only_mid_block is False:
+            self.middle_block = TimestepEmbedSequential(
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
+                SpatialTransformer(  # always uses a self-attn
+                    ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                ),
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
+            )
+        else:
+            self.middle_block = TimestepEmbedSequential(
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                    operations=operations
+                ),
+            )
         self._feature_size += ch
         transformer_depth = upsampling_depth if upsampling_depth is not None else transformer_depth
         self.output_blocks = nn.ModuleList([])
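
The bug: the previous revision placed the statement `if (resnet_only_mid_block is False):` inside the argument list of the `TimestepEmbedSequential(...)` call, which is a SyntaxError in Python, since statements cannot appear inside a call expression. The fix hoists the branch above the constructor so each variant of the middle block is built whole. A minimal sketch of the corrected pattern, using stand-in modules rather than the real ResBlock/SpatialTransformer:

    import torch
    import torch.nn as nn

    def build_middle_block(ch: int, resnet_only_mid_block: bool) -> nn.Sequential:
        # Branch *before* constructing the container; an `if` statement can
        # never sit inside a call's argument list, which was the bug here.
        if resnet_only_mid_block is False:
            return nn.Sequential(
                nn.Conv2d(ch, ch, 3, padding=1),  # stand-in for the first ResBlock
                nn.Identity(),                    # stand-in for the SpatialTransformer
                nn.Conv2d(ch, ch, 3, padding=1),  # stand-in for the second ResBlock
            )
        else:
            return nn.Sequential(
                nn.Conv2d(ch, ch, 3, padding=1),  # resnet-only middle block
            )

    x = torch.randn(1, 320, 16, 16)
    print(build_middle_block(320, True)(x).shape)  # torch.Size([1, 320, 16, 16])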


@@ -193,7 +193,7 @@ class SSD1B(SDXL):
     unet_config = {
         "model_channels": 320,
         "use_linear_in_transformer": True,
-        "transformer_depth": [0, 2, 4], # SDXL is [0, 2, 10] here
+        "transformer_depth": [0, 2, 4],
         "upsampling_depth": [0,[2,1,1],[4,4,10]],
         "resnet_only_mid_block": True,
         "context_dim": 2048,