Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-10 13:32:36 +08:00
vae_refiner: roll the convolution through time II
Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. Added support for the encoder, lowered the decoder chunks to 1 latent frame to save more VRAM, and made it work for Hunyuan Image 3.0 (which shares this code). Fixed names and cleaned up the code.
This commit is contained in:
parent d8858cb58b
commit 023036ef9d
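
The mechanism, as a minimal self-contained sketch (plain PyTorch; the names, shapes, and chunk size below are illustrative, not taken from this file): a kernel-3 convolution run on short temporal chunks reproduces the full-sequence causal convolution exactly, as long as each chunk is prefixed by the last two input frames of its predecessor and only the first chunk receives the replicate causal padding.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv3d(4, 4, kernel_size=3)   # stand-in for the VAE's 3x3x3 convs
x = torch.randn(1, 4, 8, 16, 16)              # (batch, channels, frames, height, width)

# Reference: replicate-pad spatially by 1 and causally by 2 frames, convolve once.
ref = conv(F.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate'))

# Chunked: the first chunk gets the causal pad, later chunks get carried frames.
outs, carry = [], None
for chunk in torch.split(x, 2, dim=2):
    if carry is None:
        padded = F.pad(chunk, (1, 1, 1, 1, 2, 0), mode='replicate')
    else:
        padded = F.pad(torch.cat([carry, chunk], dim=2), (1, 1, 1, 1, 0, 0), mode='replicate')
    carry = chunk[:, :, -2:].clone()          # the FIFO payload: last 2 input frames
    outs.append(conv(padded))

assert torch.allclose(torch.cat(outs, dim=2), ref, atol=1e-5)

Only the current chunk plus a two-frame carry is ever resident at the convolution, which is where the VRAM saving comes from.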
@@ -6,9 +6,8 @@ import comfy.ops
 import comfy.ldm.models.autoencoder
 
 ops = comfy.ops.disable_weight_init
 
-class SpatialPadConv3d(nn.Module):
-    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
-
+class NoPadConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
         super().__init__()
         self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
@@ -16,6 +15,28 @@ class SpatialPadConv3d(nn.Module):
         return self.conv(x)
 
 
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+    x = xl[0]
+    xl.clear()
+
+    if conv_carry_out is not None:
+        to_push = x[:, :, -2:, :, :].clone()
+        conv_carry_out.append(to_push)
+
+    if isinstance(op, NoPadConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+        else:
+            carry_len = conv_carry_in[0].shape[2]
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+
+    out = op(x)
+
+    return out
+
+
 class RMS_norm(nn.Module):
     def __init__(self, dim):
         super().__init__()
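
One subtlety in the new helper: when the previous chunk held a single frame (the refiner encoder feeds the leading frame on its own, see the Encoder hunk below), the carry is shorter than the kernel's causal reach, and the missing `2 - carry_len` frames are replicate-padded instead. Since that lone carried frame is the start of the sequence, replicating it matches what the unchunked causal pad would have produced. A toy shape check (values assumed for illustration):

import torch
import torch.nn.functional as F

carry = torch.randn(1, 4, 1, 8, 8)      # previous chunk had a single frame
chunk = torch.randn(1, 4, 4, 8, 8)
x = torch.cat([carry, chunk], dim=2)
x = F.pad(x, (1, 1, 1, 1, 2 - carry.shape[2], 0), mode='replicate')
print(x.shape[2])                        # 6: the 4 chunk frames plus 2 frames of causal context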
@@ -37,11 +58,12 @@ class DnSmpl(nn.Module):
         self.tds = tds
         self.gs = fct * ic // oc
 
-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         r1 = 2 if self.tds else 1
-        h = self.conv(x)
+        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
 
-        if self.tds and self.refiner_vae:
+        if self.tds and self.refiner_vae and conv_carry_in is None:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -49,14 +71,7 @@ class DnSmpl(nn.Module):
             hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
             hf = torch.cat([hf, hf], dim=1)
 
-            hn = h[:, :, 1:, :, :]
-            b, c, frms, ht, wd = hn.shape
-            nf = frms // r1
-            hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
-            hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
-
-            h = torch.cat([hf, hn], dim=2)
+            h = h[:, :, 1:, :, :]
 
             xf = x[:, :, :1, :, :]
             b, ci, f, ht, wd = xf.shape
@@ -64,20 +79,14 @@ class DnSmpl(nn.Module):
             xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
             xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
             B, C, T, H, W = xf.shape
-            xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
-
-            xn = x[:, :, 1:, :, :]
-            b, ci, frms, ht, wd = xn.shape
-            nf = frms // r1
-            xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
-            xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
-            B, C, T, H, W = xn.shape
-            xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
-            sc = torch.cat([xf, xn], dim=2)
-        else:
-            b, c, frms, ht, wd = h.shape
-            nf = frms // r1
-            h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
-            h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
+            xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
+
+            x = x[:, :, 1:, :, :]
+
+            if h.shape[2] == 0:
+                return hf + xf
+
+        b, c, frms, ht, wd = h.shape
+        nf = frms // r1
+        h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+        h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@@ -85,33 +94,17 @@ class DnSmpl(nn.Module):
 
-            b, ci, frms, ht, wd = x.shape
-            nf = frms // r1
-            sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
-            sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
-            B, C, T, H, W = sc.shape
-            sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
-
-        return h + sc
-
-
-def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None):
-    x = xl[0]
-    xl.clear()
-
-    if conv_carry_out is not None:
-        to_push = x[:, :, -2:, :, :].clone()
-        conv_carry_out.append(to_push)
-
-    if isinstance(op, SpatialPadConv3d):
-        if conv_carry_in is None:
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
-        else:
-            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 0, 0), mode = 'replicate')
-
-    out = op(x)
-
-    return out
+        b, ci, frms, ht, wd = x.shape
+        nf = frms // r1
+        x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+        x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
+        x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+        B, C, T, H, W = x.shape
+        x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+
+        if self.tds and self.refiner_vae and conv_carry_in is None:
+            h = torch.cat([hf, h], dim=2)
+            x = torch.cat([xf, x], dim=2)
+
+        return h + x
 
 
 class UpSmpl(nn.Module):
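
For orientation, the shortcut these DnSmpl hunks rearrange is a space-to-channel (and, with temporal downsampling, time-to-channel) pixel-unshuffle whose channel groups are then averaged down to the convolution's output width, so that `h + x` is well formed. A toy version with assumed sizes:

import torch

b, ci, frms, ht, wd = 1, 32, 4, 16, 16
c_out, r1 = 64, 2                       # r1 = 2: temporal downsample by 2
x = torch.randn(b, ci, frms, ht, wd)

gs = (r1 * 2 * 2 * ci) // c_out         # group size folded into each output channel
nf = frms // r1
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
x = x.view(b, c_out, gs, nf, ht // 2, wd // 2).mean(dim=2)
print(x.shape)  # torch.Size([1, 64, 2, 8, 8]) -- same shape as the conv output h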
@@ -126,10 +119,9 @@ class UpSmpl(nn.Module):
 
     def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         r1 = 2 if self.tus else 1
-        h = conv_carry([x], self.conv, conv_carry_in, conv_carry_out)
+        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
 
-        if self.tus and self.refiner_vae:
-            if conv_carry_in is None:
-                hf = h[:, :, :1, :, :]
-                b, c, f, ht, wd = hf.shape
-                nc = c // (2 * 2)
+        if self.tus and self.refiner_vae and conv_carry_in is None:
+            hf = h[:, :, :1, :, :]
+            b, c, f, ht, wd = hf.shape
+            nc = c // (2 * 2)
@@ -164,27 +156,28 @@ class UpSmpl(nn.Module):
         x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
         x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
 
-            if conv_carry_in is None:
-                h = torch.cat([hf, h], dim=2)
-                sc = torch.cat([xf, x], dim=2)
-            else:
-                sc = x
-        else:
-            #FIXME: make this work
-            b, c, frms, ht, wd = h.shape
-            nc = c // (r1 * 2 * 2)
-            h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
-
-            sc = x.repeat_interleave(repeats=self.rp, dim=1)
-            b, c, frms, ht, wd = sc.shape
-            nc = c // (r1 * 2 * 2)
-            sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
-
-        return h + sc
+        if self.tus and self.refiner_vae and conv_carry_in is None:
+            h = torch.cat([hf, h], dim=2)
+            x = torch.cat([xf, x], dim=2)
+
+        return h + x
+
+
+class HunyuanRefinerResnetBlock(ResnetBlock):
+    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
+        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
+        h = x
+        h = [ self.swish(self.norm1(x)) ]
+        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+
+        h = [ self.dropout(self.swish(self.norm2(h))) ]
+        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x+h
 
 
 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
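
A detail worth noting in the new HunyuanRefinerResnetBlock: activations are handed to conv_carry_causal_3d as one-element lists, and the helper pops the tensor and clears the list before padding. My reading (not stated in the commit) is that this drops the caller's last reference, so the pre-pad activation can be freed as soon as the padded copy exists instead of surviving until the call returns. A simplified sketch of the hand-off:

import torch
import torch.nn.functional as F

def consume_and_pad(xl):
    x = xl[0]
    xl.clear()               # the list held the caller's only reference
    x = F.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
    return x                 # the rebind above already let the original be freed

h = [torch.randn(1, 4, 2, 8, 8)]   # bind the name to a list, not the tensor
h = consume_and_pad(h)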
@@ -197,7 +190,7 @@ class Encoder(nn.Module):
 
         self.refiner_vae = refiner_vae
         if self.refiner_vae:
-            conv_op = VideoConv3d
+            conv_op = NoPadConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -212,9 +205,8 @@ class Encoder(nn.Module):
 
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
-                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
@@ -225,9 +217,9 @@ class Encoder(nn.Module):
             self.down.append(stage)
 
         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
 
         self.norm_out = norm_op(ch)
         self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -238,21 +230,46 @@ class Encoder(nn.Module):
         if not self.refiner_vae and x.shape[2] == 1:
             x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
 
-        x = self.conv_in(x)
+        if self.refiner_vae:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > 1:
+                xl += torch.split(x[:, :, 1:, :, :], self.ffactor_temporal, dim=2)
+            x = xl
+        else:
+            x = [x]
+        out = []
 
-        for stage in self.down:
-            for blk in stage.block:
-                x = blk(x)
-            if hasattr(stage, 'downsample'):
-                x = stage.downsample(x)
+        conv_carry_in = None
 
-        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+            x1 = [ x1 ]
+            x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+            for stage in self.down:
+                for blk in stage.block:
+                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                if hasattr(stage, 'downsample'):
+                    x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
+
+            out.append(x1)
+            conv_carry_in = conv_carry_out
+
+        if len(out) > 1:
+            out = torch.cat(out, dim=2)
+        else:
+            out = out[0]
+
+        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
+        del out
 
         b, c, t, h, w = x.shape
         grp = c // (self.z_channels << 1)
         skip = x.view(b, c // grp, grp, t, h, w).mean(2)
 
-        out = self.conv_out(F.silu(self.norm_out(x))) + skip
+        out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
 
         if self.refiner_vae:
             out = self.regul(out)[0]
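
The chunk layout the encoder now builds: the first input frame goes through alone and the remainder is split into ffactor_temporal-sized chunks, so each chunk corresponds to exactly one latent frame after temporal downsampling. A quick shape check (ffactor_temporal=4 and the frame count are assumed values):

import torch

ffactor_temporal = 4
x = torch.randn(1, 3, 1 + 3 * ffactor_temporal, 64, 64)   # 13 input frames

xl = [x[:, :, :1, :, :]]
xl += torch.split(x[:, :, 1:, :, :], ffactor_temporal, dim=2)
print([c.shape[2] for c in xl])   # [1, 4, 4, 4] -> four latent frames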
@@ -266,23 +283,6 @@ class Encoder(nn.Module):
 
         return out
 
-class HunyuanRefinerResnetBlock(ResnetBlock):
-    def __init__(self, in_channels, out_channels, conv_op=SpatialPadConv3d, norm_op=RMS_norm):
-        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=SpatialPadConv3d, norm_op=RMS_norm)
-
-    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
-        h = x
-        h = [ self.swish(self.norm1(x)) ]
-        h = conv_carry(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        h = [ self.dropout(self.swish(self.norm2(h))) ]
-        h = conv_carry(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-
-        return x+h
-
 
 class Decoder(nn.Module):
     def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
@@ -294,7 +294,7 @@ class Decoder(nn.Module):
 
         self.refiner_vae = refiner_vae
         if self.refiner_vae:
-            conv_op = SpatialPadConv3d
+            conv_op = NoPadConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -339,16 +339,21 @@ class Decoder(nn.Module):
         # z = z.permute(0, 2, 1, 3, 4)
         # z = z[:, :, 1:]
 
-        x = conv_carry([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+        x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
 
+        if self.refiner_vae:
+            x = torch.split(x, 1, dim=2)
+        else:
+            x = [ x ]
+        out = []
+
         conv_carry_in = None
 
-        x = torch.split(x, 2, dim=2)
-        out = []
-
         for i, x1 in enumerate(x):
             conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
             for stage in self.up:
                 for blk in stage.block:
                     x1 = blk(x1, conv_carry_in, conv_carry_out)
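
How the queue stays consistent across chunks: every wrapped convolution appends its tail frames to conv_carry_out in call order, and the next chunk's convolutions pop from the front in the same order, so each conv gets back exactly its own seam; the final chunk passes conv_carry_out=None so nothing is queued past the end of the sequence. A toy trace of that invariant (not the real API):

def fake_conv(name, carry_in, carry_out):
    if carry_out is not None:
        carry_out.append(name)            # stands in for x[:, :, -2:, :, :].clone()
    if carry_in is not None:
        assert carry_in.pop(0) == name    # FIFO: pop order == push order == call order

carry_in = None
chunks = ["chunk0", "chunk1", "chunk2"]
for i, chunk in enumerate(chunks):
    carry_out = None if i == len(chunks) - 1 else []
    for conv in ["conv_in", "up_block", "upsample", "conv_out"]:
        fake_conv(conv, carry_in, carry_out)
    carry_in = carry_out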
@@ -356,15 +361,19 @@ class Decoder(nn.Module):
                 x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
 
             x1 = [ F.silu(self.norm_out(x1)) ]
-            x1 = conv_carry(x1, self.conv_out, conv_carry_in, conv_carry_out)
+            x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
             out.append(x1)
             conv_carry_in = conv_carry_out
         del x
 
-        out = torch.cat(out, dim=2)
+        if len(out) > 1:
+            out = torch.cat(out, dim=2)
+        else:
+            out = out[0]
 
         if not self.refiner_vae:
             if z.shape[-3] == 1:
                 out = out[:, :, -1:]
 
         return out