resblock fix

This commit is contained in:
Yousef Rafat 2025-11-17 06:50:54 +02:00
parent 61b1efdaf0
commit 3f71760913
2 changed files with 5 additions and 5 deletions

View File

@ -357,6 +357,7 @@ class UNetDown(nn.Module):
channels=hidden_channels,
emb_channels=emb_channels,
out_channels=out_channels,
use_scale_shift_norm = True,
dropout=dropout,
**factory_kwargs
))
@ -365,6 +366,7 @@ class UNetDown(nn.Module):
self.model.append(ResBlock(
channels=hidden_channels,
emb_channels=emb_channels,
use_scale_shift_norm = True,
out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
dropout=dropout,
down=True,
@ -401,6 +403,7 @@ class UNetUp(nn.Module):
channels=in_channels,
emb_channels=emb_channels,
out_channels=hidden_channels,
use_scale_shift_norm = True,
dropout=dropout,
**factory_kwargs
))
@ -410,6 +413,7 @@ class UNetUp(nn.Module):
channels=in_channels if i == 0 else hidden_channels,
emb_channels=emb_channels,
out_channels=hidden_channels,
use_scale_shift_norm = True,
dropout=dropout,
up=True,
**factory_kwargs

View File

@ -268,11 +268,7 @@ class ResBlock(TimestepBlock):
if emb_out is not None:
if self.exchange_temb_dims:
emb_out = emb_out.movedim(1, 2)
try:
h = h + emb_out
except:
emb_out = emb_out.movedim(1, 2)
h = h + emb_out
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h