# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
import logging

from comfy import model_management
import comfy.ops
ops = comfy.ops.disable_weight_init

if model_management.xformers_enabled_vae():
    import xformers
    import xformers.ops


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models
    (originally from Fairseq) and in tensor2tensor, but differs slightly from the
    description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return torch.nn.functional.silu(x)


def Normalize(in_channels, num_groups=32):
    return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


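# 3D convolution wrapper for video latents. With a non-zero padding it pads the two
# spatial dims symmetrically and the temporal dim only at the front (kernel_size - 1
# frames, replicate by default), which appears intended to keep the conv causal in time.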
class VideoConv3d(nn.Module):
    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
        super().__init__()

        self.padding_mode = padding_mode
        if padding != 0:
            padding = (padding, padding, padding, padding, kernel_size - 1, 0)
        else:
            kwargs["padding"] = padding

        self.padding = padding
        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        if self.padding != 0:
            x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
        return self.conv(x)


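# Nearest-neighbour upsampling with a fallback path: if F.interpolate is not implemented
# for the input dtype (e.g. bf16 on some backends), the channels are processed in chunks
# in float32 and cast back, trading speed for compatibility.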
def interpolate_up(x, scale_factor):
    try:
        return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
    except:  # operation not implemented for bf16
        orig_shape = list(x.shape)
        out_shape = orig_shape[:2]
        for i in range(len(orig_shape) - 2):
            out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
        out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
        split = 8
        l = out.shape[1] // split
        for i in range(0, out.shape[1], l):
            out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
        return out


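# Upsample block. For 5D (video) inputs with a temporal scale factor, the first frame is
# only upscaled spatially while the remaining frames are upscaled in both time and space,
# presumably to avoid extrapolating before the first frame.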
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
        super().__init__()
        self.with_conv = with_conv
        self.scale_factor = scale_factor

        if self.with_conv:
            self.conv = conv_op(in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)

    def forward(self, x):
        scale_factor = self.scale_factor
        if isinstance(scale_factor, (int, float)):
            scale_factor = (scale_factor,) * (x.ndim - 2)

        if x.ndim == 5 and scale_factor[0] > 1.0:
            t = x.shape[2]
            if t > 1:
                a, b = x.split((1, t - 1), dim=2)
                del x
                b = interpolate_up(b, scale_factor)
            else:
                a = x

            a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
            if t > 1:
                x = torch.cat((a, b), dim=2)
            else:
                x = a
        else:
            x = interpolate_up(x, scale_factor)
        if self.with_conv:
            x = self.conv(x)
        return x


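# Downsample block: strided conv with manual asymmetric padding (torch convs only pad
# symmetrically) -- right/bottom zero padding for images, replicate padding with two
# extra leading frames for video inputs -- or average pooling when with_conv is False.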
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = conv_op(in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=stride,
                                padding=0)

    def forward(self, x):
        if self.with_conv:
            if x.ndim == 4:
                pad = (0, 1, 0, 1)
                mode = "constant"
                x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
            elif x.ndim == 5:
                pad = (1, 1, 1, 1, 2, 0)
                mode = "replicate"
                x = torch.nn.functional.pad(x, pad, mode=mode)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


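# Standard GroupNorm + SiLU + conv residual block, with an optional timestep-embedding
# projection and a 1x1 (or 3x3) shortcut when the channel count changes.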
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = norm_op(in_channels)
        self.conv1 = conv_op(in_channels,
                             out_channels,
                             kernel_size=3,
                             stride=1,
                             padding=1)
        if temb_channels > 0:
            self.temb_proj = ops.Linear(temb_channels,
                                        out_channels)
        self.norm2 = norm_op(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = conv_op(out_channels,
                             out_channels,
                             kernel_size=3,
                             stride=1,
                             padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = conv_op(in_channels,
                                             out_channels,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1)
            else:
                self.nin_shortcut = conv_op(in_channels,
                                            out_channels,
                                            kernel_size=1,
                                            stride=1,
                                            padding=0)

    def forward(self, x, temb=None):
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.swish(temb))[:,:,None,None]

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h


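# Memory-aware attention: estimates the memory needed for the full attention matrix,
# splits the query into slices accordingly, and doubles the number of slices on OOM
# (up to 128 steps) before giving up.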
def slice_attention(q, k, v):
    r1 = torch.zeros_like(k, device=q.device)
    scale = (int(q.shape[-1])**(-0.5))

    mem_free_total = model_management.get_free_memory(q.device)

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                s1 = torch.bmm(q[:, i:end], k) * scale

                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
                del s1

                r1[:, :, i:end] = torch.bmm(v, s2)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            model_management.soft_empty_cache(True)
            steps *= 2
            if steps > 128:
                raise e
            logging.warning("out of memory error, increasing steps and trying again {}".format(steps))

    return r1


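# Three interchangeable single-head attention implementations over flattened spatial
# positions; each takes q, k, v of shape (B, C, ...) and returns a tensor of the same shape.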
def normal_attention(q, k, v):
    # compute attention
    orig_shape = q.shape
    b = orig_shape[0]
    c = orig_shape[1]

    q = q.reshape(b, c, -1)
    q = q.permute(0, 2, 1)   # b,hw,c
    k = k.reshape(b, c, -1)  # b,c,hw
    v = v.reshape(b, c, -1)

    r1 = slice_attention(q, k, v)
    h_ = r1.reshape(orig_shape)
    del r1
    return h_


def xformers_attention(q, k, v):
    # compute attention
    orig_shape = q.shape
    B = orig_shape[0]
    C = orig_shape[1]
    q, k, v = map(
        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
        (q, k, v),
    )

    try:
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        out = out.transpose(1, 2).reshape(orig_shape)
    except NotImplementedError:
        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
    return out


def pytorch_attention(q, k, v):
    # compute attention
    orig_shape = q.shape
    B = orig_shape[0]
    C = orig_shape[1]
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )

    try:
        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        out = out.transpose(2, 3).reshape(orig_shape)
    except model_management.OOM_EXCEPTION:
        logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
    return out


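# Picks the attention implementation used in the VAE based on what model_management
# reports as available: xformers, then pytorch SDPA, then the split (sliced) fallback.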
def vae_attention():
    if model_management.xformers_enabled_vae():
        logging.info("Using xformers attention in VAE")
        return xformers_attention
    elif model_management.pytorch_attention_enabled():
        logging.info("Using pytorch attention in VAE")
        return pytorch_attention
    else:
        logging.info("Using split attention in VAE")
        return normal_attention


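# Self-attention block over spatial positions: 1x1 convs produce q, k, v, the selected
# attention backend mixes them, and the projected result is added back to the input.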
class AttnBlock(nn.Module):
    def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
        super().__init__()
        self.in_channels = in_channels

        self.norm = norm_op(in_channels)
        self.q = conv_op(in_channels,
                         in_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)
        self.k = conv_op(in_channels,
                         in_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)
        self.v = conv_op(in_channels,
                         in_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)
        self.proj_out = conv_op(in_channels,
                                in_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0)

        self.optimized_attention = vae_attention()

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        h_ = self.optimized_attention(q, k, v)

        h_ = self.proj_out(h_)

        return x+h_


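# Note: attn_type and attn_kwargs are accepted for config compatibility but ignored;
# a vanilla AttnBlock is always returned.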
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
    return AttnBlock(in_channels, conv_op=conv_op)


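# Full diffusion UNet (downsampling path, mid block with attention, upsampling path with
# skip connections, optional timestep embedding), kept alongside the VAE Encoder/Decoder.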
class Model(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                ops.Linear(self.ch,
                           self.temb_ch),
                ops.Linear(self.temb_ch,
                           self.temb_ch),
            ])

        # downsampling
        self.conv_in = ops.Conv2d(in_channels,
                                  self.ch,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = ops.Conv2d(block_in,
                                   out_ch,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)

    def forward(self, x, t=None, context=None):
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight


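# VAE encoder: conv_in -> per-resolution ResnetBlocks (optionally with attention and a
# Downsample) -> mid block -> conv_out producing z (channels doubled when double_z, for
# mean/logvar). conv3d/time_compress switch the convs to VideoConv3d and limit how many
# levels also downsample in time. Illustrative only (not a config taken from this repo),
# an SD-style image VAE encoder would look roughly like:
#   Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[],
#           in_channels=3, resolution=256, z_channels=4, double_z=True)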
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 conv3d=False, time_compress=None,
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        if conv3d:
            conv_op = VideoConv3d
            mid_attn_conv_op = ops.Conv3d
        else:
            conv_op = ops.Conv2d
            mid_attn_conv_op = ops.Conv2d

        # downsampling
        self.conv_in = conv_op(in_channels,
                               self.ch,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout,
                                         conv_op=conv_op))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                stride = 2
                if time_compress is not None:
                    if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
                        stride = (1, 2, 2)
                down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout,
                                       conv_op=conv_op)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout,
                                       conv_op=conv_op)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = conv_op(block_in,
                                2*z_channels if double_z else z_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)

    def forward(self, x):
        # timestep embedding
        temb = None
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions-1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


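# VAE decoder: conv_in from z -> mid block -> per-resolution ResnetBlocks with Upsample
# -> conv_out. give_pre_end returns the features before the final norm/act/conv, tanh_out
# squashes the output; conv3d/time_compress mirror the Encoder's video handling, limiting
# which levels also upsample in time.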
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 conv_out_op=ops.Conv2d,
                 resnet_op=ResnetBlock,
                 attn_op=AttnBlock,
                 conv3d=False,
                 time_compress=None,
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        if conv3d:
            conv_op = VideoConv3d
            conv_out_op = VideoConv3d
            mid_attn_conv_op = ops.Conv3d
        else:
            conv_op = ops.Conv2d
            mid_attn_conv_op = ops.Conv2d

        # compute block_in and curr_res at lowest res
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        logging.debug("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = conv_op(z_channels,
                               block_in,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout,
                                     conv_op=conv_op)
        self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
        self.mid.block_2 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout,
                                     conv_op=conv_op)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(resnet_op(in_channels=block_in,
                                       out_channels=block_out,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout,
                                       conv_op=conv_op))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(attn_op(block_in, conv_op=conv_op))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                scale_factor = 2.0
                if time_compress is not None:
                    if i_level > math.log2(time_compress):
                        scale_factor = (1.0, 2.0, 2.0)

                up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = conv_out_op(block_in,
                                    out_ch,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

    def forward(self, z, **kwargs):
        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h