from contextlib import nullcontext
from typing import Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
from comfy.ldm.modules.attention import optimized_attention


class SpatialNorm(nn.Module):
    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        f_size = f.shape[-2:]
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f


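# Illustrative sketch (not part of the model definition): SpatialNorm
# group-normalizes a feature map `f` and modulates it with a conditioning map
# `zq` that is nearest-resized to `f`'s spatial size. Shapes are arbitrary.
def _spatial_norm_example():
    norm = SpatialNorm(f_channels=64, zq_channels=16)
    f = torch.randn(1, 64, 32, 32)
    zq = torch.randn(1, 16, 8, 8)  # resized to 32x32 inside forward()
    return norm(f, zq)  # -> [1, 64, 32, 32]

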
# partial implementation of diffusers's Attention for comfyui
class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        out_dim: Optional[int] = None,
        out_context_dim: Optional[int] = None,
        context_pre_only=None,
        pre_only=False,
        is_causal: bool = False,
    ):
        super().__init__()

        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.is_causal = is_causal

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        self.norm_q = None
        self.norm_k = None

        self.norm_cross = None
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        self.added_proj_bias = added_proj_bias
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
            if self.context_pre_only is not None:
                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
        else:
            self.add_q_proj = None
            self.add_k_proj = None
            self.add_v_proj = None

        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(nn.Dropout(dropout))
        else:
            self.to_out = None

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
        else:
            self.to_add_out = None

        self.norm_added_q = None
        self.norm_added_k = None

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int
    ) -> torch.Tensor:
        # Minimal port of diffusers' `prepare_attention_mask`, added here so
        # the masked path in __call__ works without importing diffusers.
        if attention_mask is None:
            return attention_mask
        head_size = self.heads
        current_length = attention_mask.shape[-1]
        if current_length != target_length:
            attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
        if attention_mask.shape[0] < batch_size * head_size:
            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        return attention_mask

    def __call__(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        if self.spatial_norm is not None:
            hidden_states = self.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        hidden_states = optimized_attention(query, key, value, heads=self.heads, mask=attention_mask, skip_reshape=True, skip_output_reshape=True)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor

        return hidden_states


def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor):
    with torch.no_grad():
        depth = weight_3d.size(2)
        weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
    return weight_3d


def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor):
    with torch.no_grad():
        bias_3d.copy_(bias_2d)
    return bias_3d


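# Illustrative sketch: inflating a 2D conv checkpoint into a 3D conv. The 2D
# kernel is repeated over the temporal axis and divided by depth, so the
# inflated conv reproduces the 2D output on temporally constant inputs.
# Layer sizes below are arbitrary.
def _inflate_example():
    conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1)
    conv3d = nn.Conv3d(4, 8, kernel_size=3, padding=1)
    inflate_weight(conv2d.weight, conv3d.weight)  # [8, 4, 3, 3] -> [8, 4, 3, 3, 3]
    inflate_bias(conv2d.bias, conv3d.bias)
    return conv3d

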
def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
    weight_name = prefix + "weight"
    bias_name = prefix + "bias"
    if weight_name in state_dict:
        weight_2d = state_dict[weight_name]
        if weight_2d.dim() == 4:
            weight_3d = inflate_weight_fn(
                weight_2d=weight_2d,
                weight_3d=layer.weight,
            )
            state_dict[weight_name] = weight_3d
        else:
            # The stored weight is already 3D; nothing to inflate.
            return state_dict
    if bias_name in state_dict:
        bias_2d = state_dict[bias_name]
        if bias_2d.dim() == 1:
            bias_3d = inflate_bias_fn(
                bias_2d=bias_2d,
                bias_3d=layer.bias,
            )
            state_dict[bias_name] = bias_3d
    return state_dict


def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    input_dtype = x.dtype
    if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)):
        if x.ndim == 4:
            x = rearrange(x, "b c h w -> b h w c")
            x = norm_layer(x)
            x = rearrange(x, "b h w c -> b c h w")
            return x.to(input_dtype)
        if x.ndim == 5:
            x = rearrange(x, "b c t h w -> b t h w c")
            x = norm_layer(x)
            x = rearrange(x, "b t h w c -> b c t h w")
            return x.to(input_dtype)
    if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
        if x.ndim <= 4:
            return norm_layer(x).to(input_dtype)
        if x.ndim == 5:
            t = x.size(2)
            x = rearrange(x, "b c t h w -> (b t) c h w")
            memory_occupy = x.numel() * x.element_size() / 1024**3
            if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"):  # TODO: this may be set dynamically from the vae
                num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups)
                assert norm_layer.num_groups % num_chunks == 0
                num_groups_per_chunk = norm_layer.num_groups // num_chunks

                x = list(x.chunk(num_chunks, dim=1))
                weights = norm_layer.weight.chunk(num_chunks, dim=0)
                biases = norm_layer.bias.chunk(num_chunks, dim=0)
                for i, (w, b) in enumerate(zip(weights, biases)):
                    x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps)
                    x[i] = x[i].to(input_dtype)
                x = torch.cat(x, dim=1)
            else:
                x = norm_layer(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
            return x.to(input_dtype)
    raise NotImplementedError


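# Illustrative sketch: for 5D video tensors, the wrapper folds time into the
# batch axis so 2D norms see per-frame inputs, then restores the layout.
def _causal_norm_example():
    gn = nn.GroupNorm(num_groups=8, num_channels=32)
    video = torch.randn(2, 32, 5, 16, 16)  # [b, c, t, h, w]
    return causal_norm_wrapper(gn, video)  # same shape, normalized per frame

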
def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
    """Safe interpolate operation that handles Half precision for problematic modes"""
    # Modes that can fail under Half precision
    problematic_modes = ['bilinear', 'bicubic', 'trilinear']

    if mode in problematic_modes:
        try:
            return F.interpolate(
                x,
                size=size,
                scale_factor=scale_factor,
                mode=mode,
                align_corners=align_corners,
                recompute_scale_factor=recompute_scale_factor
            )
        except RuntimeError as e:
            if ("not implemented for 'Half'" in str(e) or
                    "compute_indices_weights" in str(e)):
                original_dtype = x.dtype
                return F.interpolate(
                    x.float(),
                    size=size,
                    scale_factor=scale_factor,
                    mode=mode,
                    align_corners=align_corners,
                    recompute_scale_factor=recompute_scale_factor
                ).to(original_dtype)
            else:
                raise e
    else:
        # 'nearest' and other compatible modes need no workaround
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor
        )


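# Illustrative sketch: fp16 bilinear upsampling falls back to an fp32
# computation when the backend lacks a Half kernel, then casts back.
def _safe_interpolate_example():
    x = torch.randn(1, 3, 8, 8, dtype=torch.float16)
    return safe_interpolate_operation(x, scale_factor=2.0, mode='bilinear', align_corners=False)

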
_receptive_field_t = Literal["half", "full"]


class InflatedCausalConv3d(nn.Conv3d):
    def __init__(
        self,
        *args,
        inflation_mode,
        **kwargs,
    ):
        self.inflation_mode = inflation_mode
        self.memory = None
        super().__init__(*args, **kwargs)
        # Temporal padding is removed from the conv itself and applied
        # causally (front of the time axis only) in forward().
        self.temporal_padding = self.padding[0]
        self.padding = (0, *self.padding[1:])
        self.memory_limit = float("inf")

    def set_memory_limit(self, value: float):
        self.memory_limit = value

    def forward(
        self,
        input,
    ):
        # Causal padding: pad 2 * temporal_padding frames at the front so the
        # output at time t only depends on inputs <= t. Replicate mode (repeat
        # the first frame) is an assumption here.
        if self.temporal_padding > 0:
            input = safe_pad_operation(
                input, (0, 0, 0, 0, 2 * self.temporal_padding, 0), mode="replicate"
            )
        return super().forward(input)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


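# Illustrative sketch (assumes the front-padded forward above): a causal 3D
# conv keeps the temporal length for stride-1 kernels, and frame t never sees
# frames later than t.
def _causal_conv_example():
    conv = InflatedCausalConv3d(8, 8, kernel_size=3, padding=1, inflation_mode="tail")
    x = torch.randn(1, 8, 5, 16, 16)  # [b, c, t, h, w]
    return conv(x)  # -> [1, 8, 5, 16, 16]

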
class Upsample3D(nn.Module):
    def __init__(
        self,
        channels,
        out_channels=None,
        inflation_mode="tail",
        temporal_up: bool = False,
        spatial_up: bool = True,
        slicing: bool = False,
        interpolate=True,
        name: str = "conv",
        use_conv_transpose=False,
        use_conv: bool = False,
        padding=1,
        bias=True,
        kernel_size=None,
        **kwargs,
    ):
        super().__init__()
        self.interpolate = interpolate
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv_transpose = use_conv_transpose
        self.use_conv = use_conv
        self.name = name

        self.conv = None
        if use_conv_transpose:
            if kernel_size is None:
                kernel_size = 4
            self.conv = nn.ConvTranspose2d(
                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
            )
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

        conv = self.conv if self.name == "conv" else self.Conv2d_0

        assert type(conv) is not nn.ConvTranspose2d
        # Note: lora_layer is not passed into the constructor in the original
        # implementation, so we make a simplification: the 2D conv above is
        # replaced by an inflated causal 3D conv.
        conv = InflatedCausalConv3d(
            self.channels,
            self.out_channels,
            3,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.temporal_up = temporal_up
        self.spatial_up = spatial_up
        self.temporal_ratio = 2 if temporal_up else 1
        self.spatial_ratio = 2 if spatial_up else 1
        self.slicing = slicing

        assert not self.interpolate
        # [Override] MAGViT v2 implementation
        if not self.interpolate:
            upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
            self.upscale_conv = nn.Conv3d(
                self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
            )
            identity = (
                torch.eye(self.channels)
                .repeat(upscale_ratio, 1)
                .reshape_as(self.upscale_conv.weight)
            )
            self.upscale_conv.weight.data.copy_(identity)

        if self.name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

        # None (not False) so the `is not None` check in forward() works.
        self.norm = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        **kwargs,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if hasattr(self, "norm") and self.norm is not None:
            # [Overridden] change to causal norm.
            hidden_states = causal_norm_wrapper(self.norm, hidden_states)

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        if self.slicing:
            # Split along time to halve peak memory; the halves are processed
            # independently and concatenated at the end.
            split_size = hidden_states.size(2) // 2
            hidden_states = list(
                hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2)
            )
        else:
            hidden_states = [hidden_states]

        for i in range(len(hidden_states)):
            hidden_states[i] = self.upscale_conv(hidden_states[i])
            hidden_states[i] = rearrange(
                hidden_states[i],
                "b (x y z c) f h w -> b c (f z) (h x) (w y)",
                x=self.spatial_ratio,
                y=self.spatial_ratio,
                z=self.temporal_ratio,
            )

        if self.use_conv:
            conv = self.conv if self.name == "conv" else self.Conv2d_0
            # Apply the conv per slice so the slicing path stays functional.
            for i in range(len(hidden_states)):
                hidden_states[i] = conv(hidden_states[i])

        if not self.slicing:
            return hidden_states[0]
        else:
            return torch.cat(hidden_states, dim=2)


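# Illustrative sketch of the MAGViT-style depth-to-space step used above: a
# 1x1 conv expands channels by the upscale ratio, then rearrange scatters the
# extra channels into the spatial/temporal axes. Shapes are arbitrary.
def _depth_to_space_example():
    b, c = 1, 4
    x = torch.randn(b, c * 2 * 2 * 2, 2, 3, 3)  # 2x spatial, 2x temporal
    return rearrange(x, "b (x y z c) f h w -> b c (f z) (h x) (w y)", x=2, y=2, z=2)
    # -> [1, 4, 4, 6, 6]

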
class Downsample3D(nn.Module):
    """A 3D downsampling layer with an optional convolution."""

    def __init__(
        self,
        channels,
        out_channels=None,
        inflation_mode="tail",
        spatial_down: bool = False,
        temporal_down: bool = False,
        name: str = "conv",
        kernel_size=3,
        use_conv: bool = False,
        padding=1,
        bias=True,
        **kwargs,
    ):
        super().__init__()
        self.padding = padding
        self.name = name
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv  # needed by forward()
        self.temporal_down = temporal_down
        self.spatial_down = spatial_down

        self.temporal_ratio = 2 if temporal_down else 1
        self.spatial_ratio = 2 if spatial_down else 1

        self.temporal_kernel = 3 if temporal_down else 1
        self.spatial_kernel = 3 if spatial_down else 1

        if use_conv:
            conv = InflatedCausalConv3d(
                self.channels,
                self.out_channels,
                kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                padding=(
                    1 if self.temporal_down else 0,
                    self.padding if self.spatial_down else 0,
                    self.padding if self.spatial_down else 0,
                ),
                inflation_mode=inflation_mode,
            )
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool3d(
                kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
            )

        self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        **kwargs,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if hasattr(self, "norm") and self.norm is not None:
            # [Overridden] change to causal norm.
            hidden_states = causal_norm_wrapper(self.norm, hidden_states)

        if self.use_conv and self.padding == 0 and self.spatial_down:
            # Asymmetric spatial padding (right/bottom only) when the conv
            # itself has no spatial padding.
            pad = (0, 1, 0, 1)
            hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",
        output_scale_factor: float = 1.0,
        skip_time_act: bool = False,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
        inflation_mode="tail",
        time_receptive_field: _receptive_field_t = "half",
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.up = up
        self.down = down
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.use_in_shortcut = self.in_channels != out_channels if use_in_shortcut is None else use_in_shortcut
        self.output_scale_factor = output_scale_factor
        self.skip_time_act = skip_time_act
        self.nonlinearity = nn.SiLU()
        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None
        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        if groups_out is None:
            groups_out = groups
        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv1 = InflatedCausalConv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3),
            stride=1,
            padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1),
            inflation_mode=inflation_mode,
        )

        self.conv2 = InflatedCausalConv3d(
            self.out_channels,
            conv_2d_out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        # Defaults so forward() can safely test these attributes.
        self.upsample = None
        self.downsample = None
        self.conv_shortcut = None

        if self.up:
            self.upsample = Upsample3D(
                self.in_channels,
                use_conv=False,
                inflation_mode=inflation_mode,
                slicing=slicing,
            )
        elif self.down:
            self.downsample = Downsample3D(
                self.in_channels,
                use_conv=False,
                padding=1,
                name="op",
                inflation_mode=inflation_mode,
            )

        if self.use_in_shortcut:
            self.conv_shortcut = InflatedCausalConv3d(
                self.in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                inflation_mode=inflation_mode,
            )

    def forward(
        self, input_tensor, temb, **kwargs
    ):
        hidden_states = input_tensor

        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None:
            hidden_states = hidden_states + temb

        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        inflation_mode="tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_down: bool = True,
        spatial_down: bool = True,
    ):
        super().__init__()
        resnets = []
        temporal_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                # [Override] Replace module.
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )
            temporal_modules.append(nn.Identity())

        self.resnets = nn.ModuleList(resnets)
        self.temporal_modules = nn.ModuleList(temporal_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                        temporal_down=temporal_down,
                        spatial_down=spatial_down,
                        inflation_mode=inflation_mode,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        **kwargs,
    ) -> torch.FloatTensor:
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = temporal(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states


class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: Optional[int] = None,
        inflation_mode="tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up: bool = True,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__()
        resnets = []
        temporal_modules = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                # [Override] Replace module.
                ResnetBlock3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                    slicing=slicing,
                )
            )

            temporal_modules.append(nn.Identity())

        self.resnets = nn.ModuleList(resnets)
        self.temporal_modules = nn.ModuleList(temporal_modules)

        if add_upsample:
            # [Override] Replace module & use learnable upsample
            self.upsamplers = nn.ModuleList(
                [
                    Upsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        temporal_up=temporal_up,
                        spatial_up=spatial_up,
                        interpolate=False,
                        inflation_mode=inflation_mode,
                        slicing=slicing,
                    )
                ]
            )
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = temporal(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class UNetMidBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        inflation_mode="tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # there is always at least one resnet
        resnets = [
            # [Override] Replace module.
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            print(
                "It is not recommended to pass `attention_head_dim=None`. "
                f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=(
                            resnet_groups if resnet_time_scale_shift == "default" else None
                        ),
                        spatial_norm_dim=(
                            temb_channels if resnet_time_scale_shift == "spatial" else None
                        ),
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        video_length, frame_height, frame_width = hidden_states.size()[-3:]
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                # Attention runs per frame: fold time into the batch axis.
                hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
                hidden_states = attn(hidden_states, temb=temb)
                hidden_states = rearrange(
                    hidden_states, "(b f) c h w -> b c f h w", f=video_length
                )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


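# Illustrative sketch of the per-frame attention layout used in the mid
# block: time folds into batch, attention sees h*w tokens, then unfolds.
def _per_frame_attention_layout():
    x = torch.randn(2, 64, 3, 8, 8)  # [b, c, f, h, w]
    frames = rearrange(x, "b c f h w -> (b f) c h w")  # [6, 64, 8, 8]
    # ... a 2D attention module would run on `frames` here ...
    return rearrange(frames, "(b f) c h w -> b c f h w", f=3)

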
class Encoder3D(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        # [Override] add extra_cond_dim, temporal down num
        temporal_down_num: int = 2,
        extra_cond_dim: Optional[int] = None,
        gradient_checkpoint: bool = False,
        inflation_mode="tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_down_num = temporal_down_num

        self.conv_in = InflatedCausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])
        self.extra_cond_dim = extra_cond_dim

        self.conv_extra_cond = nn.ModuleList([])

        def zero_module(module):
            # Zero out the parameters of a module and return it.
            for p in module.parameters():
                p.detach().zero_()
            return module

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # [Override] to support temporal down block design
            # Note: take the last ones
            is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1

            assert down_block_type == "DownEncoderBlock3D"

            down_block = DownEncoderBlock3D(
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                # Note: not sure why this is set to 0
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                temporal_down=is_temporal_down_block,
                spatial_down=True,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.down_blocks.append(down_block)

            self.conv_extra_cond.append(
                zero_module(
                    nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
                )
                if self.extra_cond_dim is not None and self.extra_cond_dim > 0
                else None
            )

        # mid
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = InflatedCausalConv3d(
            block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode
        )

        self.gradient_checkpointing = gradient_checkpoint

    def forward(
        self,
        sample: torch.FloatTensor,
        extra_cond=None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""
        sample = self.conv_in(sample)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            # [Override] add extra block and extra cond
            for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), sample, use_reentrant=False
                )
                if extra_block is not None:
                    sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:])

            # middle
            sample = self.mid_block(sample)

        else:
            # down
            # [Override] add extra block and extra cond
            for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
                sample = down_block(sample)
                if extra_block is not None:
                    sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:])

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class Decoder3D(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        # [Override] add temporal up block
        inflation_mode="tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_up_num = temporal_up_num

        self.conv_in = InflatedCausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        print(f"slicing_up_num: {slicing_up_num}")
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1
            is_temporal_up_block = i < self.temporal_up_num
            # Note: Keep symmetric
            is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num

            assert up_block_type == "UpDecoderBlock3D"
            up_block = UpDecoderBlock3D(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=norm_type,
                temb_channels=temb_channels,
                temporal_up=is_temporal_up_block,
                slicing=is_slicing_up_block,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        self.conv_act = nn.SiLU()
        self.conv_out = InflatedCausalConv3d(
            block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode
        )

        self.gradient_checkpointing = gradient_checkpoint

    # Note: Just copy from Decoder.
    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        # middle
        sample = self.mid_block(sample, latent_embeds)
        sample = sample.to(upscale_dtype)

        # up
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)

        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class VideoAutoencoderKL(nn.Module):
    """
    We simply inherit the model code from diffusers.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        layers_per_block: int = 2,
        act_fn: str = "silu",
        latent_channels: int = 16,
        norm_num_groups: int = 32,
        attention: bool = True,
        temporal_scale_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
        inflation_mode="pad",
        time_receptive_field: _receptive_field_t = "full",
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
        *args,
        **kwargs,
    ):
        extra_cond_dim = kwargs.pop("extra_cond_dim", None)
        block_out_channels = (128, 256, 512, 512)
        down_block_types = ("DownEncoderBlock3D",) * 4
        up_block_types = ("UpDecoderBlock3D",) * 4
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            extra_cond_dim=extra_cond_dim,
            # [Override] add temporal_down_num parameter
            temporal_down_num=temporal_scale_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # pass init params to Decoder
        self.decoder = Decoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            # [Override] add temporal_up_num parameter
            temporal_up_num=temporal_scale_num,
            slicing_up_num=slicing_up_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        self.quant_conv = (
            InflatedCausalConv3d(
                in_channels=2 * latent_channels,
                out_channels=2 * latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_quant_conv
            else None
        )
        self.post_quant_conv = (
            InflatedCausalConv3d(
                in_channels=latent_channels,
                out_channels=latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_post_quant_conv
            else None
        )

        # A hacky way to remove attention.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
        # Note: unlike diffusers, this returns the sampled latent tensor directly.
        h = self.slicing_encode(x)
        posterior = DiagonalGaussianDistribution(h).sample()

        if not return_dict:
            return (posterior,)

        return posterior

    def decode(
        self, z: torch.Tensor, return_dict: bool = True
    ):
        decoded = self.slicing_decode(z)

        if not return_dict:
            return (decoded,)

        return decoded

    @property
    def device(self):
        # nn.Module has no `.device` attribute; follow the usual convention of
        # taking it from the first parameter (used by _encode/_decode below).
        return next(self.parameters()).device

    def _encode(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        _x = x.to(self.device)
        h = self.encoder(_x)
        if self.quant_conv is not None:
            output = self.quant_conv(h)
        else:
            output = h
        return output.to(x.device)

    def _decode(
        self, z: torch.Tensor
    ) -> torch.Tensor:
        _z = z.to(self.device)
        if self.post_quant_conv is not None:
            _z = self.post_quant_conv(_z)
        output = self.decoder(_z)
        return output.to(z.device)

    def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
        return self._encode(x)

    def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
        return self._decode(z)

    def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
    ):
        # x: [b c t h w]
        if mode == "encode":
            h = self.encode(x)
            return h.latent_dist
        elif mode == "decode":
            h = self.decode(x)
            return h.sample
        else:
            h = self.encode(x)
            h = self.decode(h.latent_dist.mode())
            return h.sample

    def load_state_dict(self, state_dict, strict=False):
        # Newer versions of diffusers changed the model keys, breaking
        # compatibility with old checkpoints. They provide a conversion
        # method, which we call before loading the state_dict.
        convert_deprecated_attention_blocks = getattr(
            self, "_convert_deprecated_attention_blocks", None
        )
        if callable(convert_deprecated_attention_blocks):
            convert_deprecated_attention_blocks(state_dict)
        return super().load_state_dict(state_dict, strict)


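# Hedged usage sketch (illustrative only; a real run needs trained weights):
# with the default 4-level config, encode() downsamples roughly 8x spatially
# and 4x temporally, and returns a sampled latent tensor (see encode() above).
def _vae_roundtrip_example():
    vae = VideoAutoencoderKL(latent_channels=16)
    x = torch.randn(1, 3, 5, 64, 64)  # t = 4k + 1 suits the causal layout
    z = vae.encode(x)
    return vae.decode(z)

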
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
    def __init__(
        self,
        *args,
        spatial_downsample_factor=8,
        temporal_downsample_factor=4,
        freeze_encoder=True,
        **kwargs,
    ):
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.freeze_encoder = freeze_encoder
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.FloatTensor):
        with torch.no_grad() if self.freeze_encoder else nullcontext():
            z, p = self.encode(x)
        # Decoding stays outside the no_grad context so the decoder can train.
        x = self.decode(z).sample
        return x, z, p

    def encode(self, x: torch.FloatTensor):
        if x.ndim == 4:
            x = x.unsqueeze(2)
        p = super().encode(x).latent_dist
        z = p.sample().squeeze(2)
        return z, p

    def decode(self, z: torch.FloatTensor):
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = super().decode(z).sample.squeeze(2)
        return x

    def preprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W]
        assert x.ndim == 4 or x.size(2) % 4 == 1
        return x

    def postprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W]
        return x

    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
        # TODO
        # set_norm_limit(norm_max_mem)
        for m in self.modules():
            if isinstance(m, InflatedCausalConv3d):
                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))