from contextlib import nullcontext from typing import Literal, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from diffusers.models.attention_processor import Attention from einops import rearrange from model import safe_pad_operation from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution from comfy.ldm.modules.attention import optimized_attention class SpatialNorm(nn.Module): def __init__( self, f_channels: int, zq_channels: int, ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: f_size = f.shape[-2:] zq = F.interpolate(zq, size=f_size, mode="nearest") norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f # partial implementation of diffusers's Attention for comfyui class Attention(nn.Module): def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, kv_heads: Optional[int] = None, dim_head: int = 64, dropout: float = 0.0, bias: bool = False, upcast_attention: bool = False, upcast_softmax: bool = False, added_kv_proj_dim: Optional[int] = None, added_proj_bias: Optional[bool] = True, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, _from_deprecated_attn_block: bool = False, out_dim: int = None, out_context_dim: int = None, context_pre_only=None, pre_only=False, is_causal: bool = False, ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads self.query_dim = query_dim self.use_bias = bias self.is_cross_attention = cross_attention_dim is not None self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only self.is_causal = is_causal # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = out_dim // dim_head if out_dim is not None else heads self.sliceable_head_dim = heads self.added_kv_proj_dim = added_kv_proj_dim self.only_cross_attention = only_cross_attention if norm_num_groups is not None: self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None if spatial_norm_dim is not None: self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) else: self.spatial_norm = None self.norm_q = None self.norm_k = None self.norm_cross = None self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) else: self.to_k = None self.to_v = None self.added_proj_bias = added_proj_bias if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) else: self.add_q_proj = None self.add_k_proj = None self.add_v_proj = None if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) else: self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) else: self.to_add_out = None self.norm_added_q = None self.norm_added_k = None def __call__( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, *args, **kwargs, ) -> torch.Tensor: residual = hidden_states if self.spatial_norm is not None: hidden_states = self.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif self.norm_cross: encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // self.heads query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) if self.norm_q is not None: query = self.norm_q(query) if self.norm_k is not None: key = self.norm_k(key) hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if self.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / self.rescale_output_factor return hidden_states def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): """ Inflate a 2D convolution weight matrix to a 3D one. Parameters: weight_2d: The weight matrix of 2D conv to be inflated. weight_3d: The weight matrix of 3D conv to be initialized. inflation_mode: the mode of inflation """ assert inflation_mode in ["tail", "replicate"] assert weight_3d.shape[:2] == weight_2d.shape[:2] with torch.no_grad(): if inflation_mode == "replicate": depth = weight_3d.size(2) weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) else: weight_3d.fill_(0.0) weight_3d[:, :, -1].copy_(weight_2d) return weight_3d def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): """ Inflate a 2D convolution bias tensor to a 3D one Parameters: bias_2d: The bias tensor of 2D conv to be inflated. bias_3d: The bias tensor of 3D conv to be initialized. inflation_mode: Placeholder to align `inflate_weight`. """ assert bias_3d.shape == bias_2d.shape with torch.no_grad(): bias_3d.copy_(bias_2d) return bias_3d def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): """ the main function to inflated 2D parameters to 3D. """ weight_name = prefix + "weight" bias_name = prefix + "bias" if weight_name in state_dict: weight_2d = state_dict[weight_name] if weight_2d.dim() == 4: # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) weight_3d = inflate_weight_fn( weight_2d=weight_2d, weight_3d=layer.weight, inflation_mode=layer.inflation_mode, ) state_dict[weight_name] = weight_3d else: return state_dict # It's a 3d state dict, should not do inflation on both bias and weight. if bias_name in state_dict: bias_2d = state_dict[bias_name] if bias_2d.dim() == 1: # Assuming the 2D biases are 1D tensors (out_channels,) bias_3d = inflate_bias_fn( bias_2d=bias_2d, bias_3d=layer.bias, inflation_mode=layer.inflation_mode, ) state_dict[bias_name] = bias_3d return state_dict def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)): if x.ndim == 4: x = rearrange(x, "b c h w -> b h w c") x = norm_layer(x) x = rearrange(x, "b h w c -> b c h w") return x.to(input_dtype) if x.ndim == 5: x = rearrange(x, "b c t h w -> b t h w c") x = norm_layer(x) x = rearrange(x, "b t h w c -> b c t h w") return x.to(input_dtype) if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): if x.ndim <= 4: return norm_layer(x).to(input_dtype) if x.ndim == 5: t = x.size(2) x = rearrange(x, "b c t h w -> (b t) c h w") memory_occupy = x.numel() * x.element_size() / 1024**3 if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups) assert norm_layer.num_groups % num_chunks == 0 num_groups_per_chunk = norm_layer.num_groups // num_chunks x = list(x.chunk(num_chunks, dim=1)) weights = norm_layer.weight.chunk(num_chunks, dim=0) biases = norm_layer.bias.chunk(num_chunks, dim=0) for i, (w, b) in enumerate(zip(weights, biases)): x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) x[i] = x[i].to(input_dtype) x = torch.cat(x, dim=1) else: x = norm_layer(x) x = rearrange(x, "(b t) c h w -> b c t h w", t=t) return x.to(input_dtype) raise NotImplementedError def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): """Safe interpolate operation that handles Half precision for problematic modes""" # Modes qui peuvent causer des problèmes avec Half precision problematic_modes = ['bilinear', 'bicubic', 'trilinear'] if mode in problematic_modes: try: return F.interpolate( x, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor ) except RuntimeError as e: if ("not implemented for 'Half'" in str(e) or "compute_indices_weights" in str(e)): original_dtype = x.dtype return F.interpolate( x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor ).to(original_dtype) else: raise e else: # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire return F.interpolate( x, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor ) _receptive_field_t = Literal["half", "full"] class InflatedCausalConv3d(nn.Conv3d): def __init__( self, *args, inflation_mode, **kwargs, ): self.inflation_mode = inflation_mode self.memory = None super().__init__(*args, **kwargs) self.temporal_padding = self.padding[0] self.padding = (0, *self.padding[1:]) self.memory_limit = float("inf") def forward( self, input, ): return super().forward(input) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): if self.inflation_mode != "none": state_dict = modify_state_dict( self, state_dict, prefix, inflate_weight_fn=inflate_weight, inflate_bias_fn=inflate_bias, ) super()._load_from_state_dict( state_dict, prefix, local_metadata, (strict and self.inflation_mode == "none"), missing_keys, unexpected_keys, error_msgs, ) class Upsample3D(nn.Module): def __init__( self, channels, out_channels = None, inflation_mode = "tail", temporal_up: bool = False, spatial_up: bool = True, slicing: bool = False, interpolate = True, name: str = "conv", use_conv_transpose = False, use_conv: bool = False, padding = 1, bias = True, kernel_size = None, **kwargs, ): super().__init__() self.interpolate = interpolate self.channels = channels self.out_channels = out_channels or channels self.use_conv_transpose = use_conv_transpose self.use_conv = use_conv self.name = name self.conv = None if use_conv_transpose: if kernel_size is None: kernel_size = 4 self.conv = nn.ConvTranspose2d( channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias ) elif use_conv: if kernel_size is None: kernel_size = 3 self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) conv = self.conv if self.name == "conv" else self.Conv2d_0 assert type(conv) is not nn.ConvTranspose2d # Note: lora_layer is not passed into constructor in the original implementation. # So we make a simplification. conv = InflatedCausalConv3d( self.channels, self.out_channels, 3, padding=1, inflation_mode=inflation_mode, ) self.temporal_up = temporal_up self.spatial_up = spatial_up self.temporal_ratio = 2 if temporal_up else 1 self.spatial_ratio = 2 if spatial_up else 1 self.slicing = slicing assert not self.interpolate # [Override] MAGViT v2 implementation if not self.interpolate: upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio self.upscale_conv = nn.Conv3d( self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 ) identity = ( torch.eye(self.channels) .repeat(upscale_ratio, 1) .reshape_as(self.upscale_conv.weight) ) self.upscale_conv.weight.data.copy_(identity) if self.name == "conv": self.conv = conv else: self.Conv2d_0 = conv self.norm = False def forward( self, hidden_states: torch.FloatTensor, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if hasattr(self, "norm") and self.norm is not None: # [Overridden] change to causal norm. hidden_states = causal_norm_wrapper(self.norm, hidden_states) if self.use_conv_transpose: return self.conv(hidden_states) if self.slicing: split_size = hidden_states.size(2) // 2 hidden_states = list( hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) ) else: hidden_states = [hidden_states] for i in range(len(hidden_states)): hidden_states[i] = self.upscale_conv(hidden_states[i]) hidden_states[i] = rearrange( hidden_states[i], "b (x y z c) f h w -> b c (f z) (h x) (w y)", x=self.spatial_ratio, y=self.spatial_ratio, z=self.temporal_ratio, ) if not self.slicing: hidden_states = hidden_states[0] if self.use_conv: if self.name == "conv": hidden_states = self.conv(hidden_states) else: hidden_states = self.Conv2d_0(hidden_states) if not self.slicing: return hidden_states else: return torch.cat(hidden_states, dim=2) class Downsample3D(nn.Module): """A 3D downsampling layer with an optional convolution.""" def __init__( self, channels, out_channels = None, inflation_mode = "tail", spatial_down: bool = False, temporal_down: bool = False, name: str = "conv", kernel_size=3, use_conv: bool = False, padding = 1, bias=True, **kwargs, ): super().__init__() self.padding = padding self.name = name self.channels = channels self.out_channels = out_channels or channels self.temporal_down = temporal_down self.spatial_down = spatial_down self.temporal_ratio = 2 if temporal_down else 1 self.spatial_ratio = 2 if spatial_down else 1 self.temporal_kernel = 3 if temporal_down else 1 self.spatial_kernel = 3 if spatial_down else 1 if use_conv: conv = InflatedCausalConv3d( self.channels, self.out_channels, kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), padding=( 1 if self.temporal_down else 0, self.padding if self.spatial_down else 0, self.padding if self.spatial_down else 0, ), inflation_mode=inflation_mode, ) else: assert self.channels == self.out_channels conv = nn.AvgPool3d( kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), ) self.conv = conv def forward( self, hidden_states: torch.FloatTensor, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if hasattr(self, "norm") and self.norm is not None: # [Overridden] change to causal norm. hidden_states = causal_norm_wrapper(self.norm, hidden_states) if self.use_conv and self.padding == 0 and self.spatial_down: pad = (0, 1, 0, 1) hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels hidden_states = self.conv(hidden_states) return hidden_states class ResnetBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: Optional[int] = None, conv_shortcut: bool = False, dropout: float = 0.0, temb_channels: int = 512, groups: int = 32, groups_out: Optional[int] = None, eps: float = 1e-6, non_linearity: str = "swish", time_embedding_norm: str = "default", output_scale_factor: float = 1.0, skip_time_act: bool = False, use_in_shortcut: Optional[bool] = None, up: bool = False, down: bool = False, conv_shortcut_bias: bool = True, conv_2d_out_channels: Optional[int] = None, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", slicing: bool = False, **kwargs, ): super().__init__() self.up = up self.down = down self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels conv_2d_out_channels = conv_2d_out_channels or out_channels self.use_in_shortcut = use_in_shortcut self.output_scale_factor = output_scale_factor self.skip_time_act = skip_time_act self.nonlinearity = nn.SiLU() if temb_channels is not None: self.time_emb_proj = nn.Linear(temb_channels, out_channels) else: self.time_emb_proj = None self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) if groups_out is None: groups_out = groups self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.use_in_shortcut = self.in_channels != out_channels self.dropout = torch.nn.Dropout(dropout) self.conv1 = InflatedCausalConv3d( self.in_channels, self.out_channels, kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), stride=1, padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), inflation_mode=inflation_mode, ) self.conv2 = InflatedCausalConv3d( self.out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) if self.up: self.upsample = Upsample3D( self.in_channels, use_conv=False, inflation_mode=inflation_mode, slicing=slicing, ) elif self.down: self.downsample = Downsample3D( self.in_channels, use_conv=False, padding=1, name="op", inflation_mode=inflation_mode, ) if self.use_in_shortcut: self.conv_shortcut = InflatedCausalConv3d( self.in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=True, inflation_mode=inflation_mode, ) def forward( self, input_tensor, temb, **kwargs ): hidden_states = input_tensor hidden_states = causal_norm_wrapper(self.norm1, hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) temb = self.time_emb_proj(temb)[:, :, None, None] if temb is not None: hidden_states = hidden_states + temb hidden_states = causal_norm_wrapper(self.norm2, hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor class DownEncoderBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", temporal_down: bool = True, spatial_down: bool = True, ): super().__init__() resnets = [] temporal_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( # [Override] Replace module. ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) temporal_modules.append(nn.Identity()) self.resnets = nn.ModuleList(resnets) self.temporal_modules = nn.ModuleList(temporal_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op", temporal_down=temporal_down, spatial_down=spatial_down, inflation_mode=inflation_mode, ) ] ) else: self.downsamplers = None def forward( self, hidden_states: torch.FloatTensor, **kwargs, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): hidden_states = resnet(hidden_states, temb=None) hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states class UpDecoderBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_upsample: bool = True, temb_channels: Optional[int] = None, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up: bool = True, spatial_up: bool = True, slicing: bool = False, ): super().__init__() resnets = [] temporal_modules = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( # [Override] Replace module. ResnetBlock3D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, slicing=slicing, ) ) temporal_modules.append(nn.Identity()) self.resnets = nn.ModuleList(resnets) self.temporal_modules = nn.ModuleList(temporal_modules) if add_upsample: # [Override] Replace module & use learnable upsample self.upsamplers = nn.ModuleList( [ Upsample3D( out_channels, use_conv=True, out_channels=out_channels, temporal_up=temporal_up, spatial_up=spatial_up, interpolate=False, inflation_mode=inflation_mode, slicing=slicing, ) ] ) else: self.upsamplers = None def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): hidden_states = resnet(hidden_states, temb=None) hidden_states = temporal(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class UNetMidBlock3D(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, add_attention: bool = True, attention_head_dim: int = 1, output_scale_factor: float = 1.0, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention # there is always at least one resnet resnets = [ # [Override] Replace module. ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ] attentions = [] if attention_head_dim is None: print( f"It is not recommend to pass `attention_head_dim=None`. " f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}." ) attention_head_dim = in_channels for _ in range(num_layers): if self.add_attention: attentions.append( Attention( in_channels, heads=in_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=( resnet_groups if resnet_time_scale_shift == "default" else None ), spatial_norm_dim=( temb_channels if resnet_time_scale_shift == "spatial" else None ), residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) else: attentions.append(None) resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward(self, hidden_states, temb=None): video_length, frame_height, frame_width = hidden_states.size()[-3:] hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = attn(hidden_states, temb=temb) hidden_states = rearrange( hidden_states, "(b f) c h w -> b c f h w", f=video_length ) hidden_states = resnet(hidden_states, temb) return hidden_states class Encoder3D(nn.Module): def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", double_z: bool = True, mid_block_add_attention=True, # [Override] add extra_cond_dim, temporal down num temporal_down_num: int = 2, extra_cond_dim: int = None, gradient_checkpoint: bool = False, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", ): super().__init__() self.layers_per_block = layers_per_block self.temporal_down_num = temporal_down_num self.conv_in = InflatedCausalConv3d( in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) self.mid_block = None self.down_blocks = nn.ModuleList([]) self.extra_cond_dim = extra_cond_dim self.conv_extra_cond = nn.ModuleList([]) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 # [Override] to support temporal down block design is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 # Note: take the last ones assert down_block_type == "DownEncoderBlock3D" down_block = DownEncoderBlock3D( num_layers=self.layers_per_block, in_channels=input_channel, out_channels=output_channel, add_downsample=not is_final_block, resnet_eps=1e-6, downsample_padding=0, # Note: Don't know why set it as 0 resnet_act_fn=act_fn, resnet_groups=norm_num_groups, temporal_down=is_temporal_down_block, spatial_down=True, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.down_blocks.append(down_block) def zero_module(module): # Zero out the parameters of a module and return it. for p in module.parameters(): p.detach().zero_() return module self.conv_extra_cond.append( zero_module( nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) ) if self.extra_cond_dim is not None and self.extra_cond_dim > 0 else None ) # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, add_attention=mid_block_add_attention, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # out self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = InflatedCausalConv3d( block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode ) self.gradient_checkpointing = gradient_checkpoint def forward( self, sample: torch.FloatTensor, extra_cond=None, ) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = self.conv_in(sample) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward # down # [Override] add extra block and extra cond for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): sample = torch.utils.checkpoint.checkpoint( create_custom_forward(down_block), sample, use_reentrant=False ) if extra_block is not None: sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) # middle sample = self.mid_block(sample) else: # down # [Override] add extra block and extra cond for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): sample = down_block(sample) if extra_block is not None: sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) # middle sample = self.mid_block(sample) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample) return sample class Decoder3D(nn.Module): def __init__( self, in_channels: int = 3, out_channels: int = 3, up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", norm_type: str = "group", # group, spatial mid_block_add_attention=True, # [Override] add temporal up block inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up_num: int = 2, slicing_up_num: int = 0, gradient_checkpoint: bool = False, ): super().__init__() self.layers_per_block = layers_per_block self.temporal_up_num = temporal_up_num self.conv_in = InflatedCausalConv3d( in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) self.mid_block = None self.up_blocks = nn.ModuleList([]) temb_channels = in_channels if norm_type == "spatial" else None # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default" if norm_type == "group" else norm_type, attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=temb_channels, add_attention=mid_block_add_attention, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] print(f"slicing_up_num: {slicing_up_num}") for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 is_temporal_up_block = i < self.temporal_up_num is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num # Note: Keep symmetric assert up_block_type == "UpDecoderBlock3D" up_block = UpDecoderBlock3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, add_upsample=not is_final_block, resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, resnet_time_scale_shift=norm_type, temb_channels=temb_channels, temporal_up=is_temporal_up_block, slicing=is_slicing_up_block, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_type == "spatial": self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() self.conv_out = InflatedCausalConv3d( block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode ) self.gradient_checkpointing = gradient_checkpoint # Note: Just copy from Decoder. def forward( self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: sample = self.conv_in(sample) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype # middle sample = self.mid_block(sample, latent_embeds) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: sample = up_block(sample, latent_embeds) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample) return sample class VideoAutoencoderKL(nn.Module): """ We simply inherit the model code from diffusers """ def __init__( self, in_channels: int = 3, out_channels: int = 3, layers_per_block: int = 2, act_fn: str = "silu", latent_channels: int = 16, norm_num_groups: int = 32, attention: bool = True, temporal_scale_num: int = 2, slicing_up_num: int = 0, gradient_checkpoint: bool = False, inflation_mode = "pad", time_receptive_field: _receptive_field_t = "full", use_quant_conv: bool = False, use_post_quant_conv: bool = False, *args, **kwargs, ): extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None block_out_channels = (128, 256, 512, 512) down_block_types = ("DownEncoderBlock3D",) * 4 up_block_types = ("UpDecoderBlock3D",) * 4 super().__init__() # pass init params to Encoder self.encoder = Encoder3D( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, extra_cond_dim=extra_cond_dim, # [Override] add temporal_down_num parameter temporal_down_num=temporal_scale_num, gradient_checkpoint=gradient_checkpoint, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # pass init params to Decoder self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, # [Override] add temporal_up_num parameter temporal_up_num=temporal_scale_num, slicing_up_num=slicing_up_num, gradient_checkpoint=gradient_checkpoint, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.quant_conv = ( InflatedCausalConv3d( in_channels=2 * latent_channels, out_channels=2 * latent_channels, kernel_size=1, inflation_mode=inflation_mode, ) if use_quant_conv else None ) self.post_quant_conv = ( InflatedCausalConv3d( in_channels=latent_channels, out_channels=latent_channels, kernel_size=1, inflation_mode=inflation_mode, ) if use_post_quant_conv else None ) # A hacky way to remove attention. if not attention: self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) def encode(self, x: torch.FloatTensor, return_dict: bool = True): h = self.slicing_encode(x) posterior = DiagonalGaussianDistribution(h).sample() if not return_dict: return (posterior,) return posterior def decode( self, z: torch.Tensor, return_dict: bool = True ): decoded = self.slicing_decode(z) if not return_dict: return (decoded,) return decoded def _encode( self, x: torch.Tensor ) -> torch.Tensor: _x = x.to(self.device) h = self.encoder(_x,) if self.quant_conv is not None: output = self.quant_conv(h) else: output = h return output.to(x.device) def _decode( self, z: torch.Tensor ) -> torch.Tensor: _z = z.to(self.device) if self.post_quant_conv is not None: _z = self.post_quant_conv(_z) output = self.decoder(_z) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: return self._encode(x) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: return self._decode(z) def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: raise NotImplementedError def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: raise NotImplementedError def forward( self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs ): # x: [b c t h w] if mode == "encode": h = self.encode(x) return h.latent_dist elif mode == "decode": h = self.decode(x) return h.sample else: h = self.encode(x) h = self.decode(h.latent_dist.mode()) return h.sample def load_state_dict(self, state_dict, strict=False): # Newer version of diffusers changed the model keys, # causing incompatibility with old checkpoints. # They provided a method for conversion. # We call conversion before loading state_dict. convert_deprecated_attention_blocks = getattr( self, "_convert_deprecated_attention_blocks", None ) if callable(convert_deprecated_attention_blocks): convert_deprecated_attention_blocks(state_dict) return super().load_state_dict(state_dict, strict) class VideoAutoencoderKLWrapper(VideoAutoencoderKL): def __init__( self, *args, spatial_downsample_factor = 8, temporal_downsample_factor = 4, freeze_encoder = True, **kwargs, ): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor self.freeze_encoder = freeze_encoder super().__init__(*args, **kwargs) def forward(self, x: torch.FloatTensor): with torch.no_grad() if self.freeze_encoder else nullcontext(): z, p = self.encode(x) x = self.decode(z).sample return x, z, p def encode(self, x: torch.FloatTensor): if x.ndim == 4: x = x.unsqueeze(2) p = super().encode(x).latent_dist z = p.sample().squeeze(2) return z, p def decode(self, z: torch.FloatTensor): if z.ndim == 4: z = z.unsqueeze(2) x = super().decode(z).sample.squeeze(2) return x def preprocess(self, x: torch.Tensor): # x should in [B, C, T, H, W], [B, C, H, W] assert x.ndim == 4 or x.size(2) % 4 == 1 return x def postprocess(self, x: torch.Tensor): # x should in [B, C, T, H, W], [B, C, H, W] return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): # TODO #set_norm_limit(norm_max_mem) for m in self.modules(): if isinstance(m, InflatedCausalConv3d): m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))