From 4b9332cc215a8ab12163908044286ec4fc9bab87 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sun, 7 Dec 2025 21:41:14 +0200
Subject: [PATCH] continue building nodes / testing vae

---
 comfy/ldm/seedvr/model.py    |  40 ++---
 comfy/ldm/seedvr/vae.py      | 315 +++++++++++++++++++++++++++++++----
 comfy_extras/nodes_seedvr.py |  79 ++++++++-
 3 files changed, 378 insertions(+), 56 deletions(-)

diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 40a460d67..cf6287b03 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1141,11 +1141,6 @@ def repeat(
     kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
     return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
 
-@dataclass
-class NaDiTOutput:
-    vid_sample: torch.Tensor
-
-
 class NaDiT(nn.Module):
     def __init__(
@@ -1246,26 +1241,32 @@ class NaDiT(nn.Module):
             "mmdit_stwin_3d_spatial",
         ]
 
-    def set_gradient_checkpointing(self, enable: bool):
-        self.gradient_checkpointing = enable
-
     def forward(
         self,
-        vid: torch.FloatTensor,  # l c
-        txt: torch.FloatTensor,  # l c
-        vid_shape: torch.LongTensor,  # b 3
-        txt_shape: torch.LongTensor,  # b 1
-        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
-        disable_cache: bool = True,  # for test
-    ):
-        # Text input.
+        x,
+        timestep,
+        context,  # l c
+        txt_shape,  # b 1
+        disable_cache: bool = True,  # for testing; TODO: revisit
+    ):
+        pos_cond, neg_cond = context.chunk(2, dim=0)
+        pos_cond, pos_shape = flatten(pos_cond)
+        neg_cond, neg_shape = flatten(neg_cond)
+        # Pad the shorter condition so positive and negative share a token length.
+        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
+        if pos_shape.shape[1] > neg_shape.shape[1]:
+            neg_shape = F.pad(neg_shape, (0, 0, 0, diff))
+            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
+        else:
+            pos_shape = F.pad(pos_shape, (0, 0, 0, diff))
+            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
+        # TODO: route the padded pos/neg conditions into txt once CFG batching is wired up.
+        txt = context
+        vid, vid_shape = flatten(x)
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
         # slice vid after patching in when using sequence parallelism
         txt = self.txt_in(txt)
 
-        # Video input.
-        # Sequence parallel slicing is done inside patching class.
         vid, vid_shape = self.vid_in(vid, vid_shape)
 
         # Embedding input.
@@ -1284,4 +1285,5 @@ class NaDiT(nn.Module):
         )
 
         vid, vid_shape = self.vid_out(vid, vid_shape, cache)
-        return NaDiTOutput(vid_sample=vid)
+        vid = unflatten(vid, vid_shape)
+        return vid
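A quick standalone illustration of the padding step above (shapes are made up): the two flattened text streams must share a token length before they can be stacked into a single CFG batch.

```python
import torch
import torch.nn.functional as F

# Two token sequences of unequal length, standing in for the flattened
# positive and negative embeddings (hypothetical shapes).
pos = torch.randn(77, 1024)  # l c
neg = torch.randn(32, 1024)  # l c

diff = abs(pos.shape[0] - neg.shape[0])
if pos.shape[0] > neg.shape[0]:
    # F.pad pads from the last dim backwards: (0, 0) leaves the channel
    # dim alone, (0, diff) appends `diff` zero rows to the token dim.
    neg = F.pad(neg, (0, 0, 0, diff))
else:
    pos = F.pad(pos, (0, 0, 0, diff))

batch = torch.stack([pos, neg])  # (2, max_l, c), ready for one CFG pass
print(batch.shape)
```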
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 51c5b2578..3a0f8cfed 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -4,11 +4,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from diffusers.models.attention_processor import Attention
-from diffusers.models.upsampling import Upsample2D
 from einops import rearrange
 from model import safe_pad_operation
 
 from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
+from comfy.ldm.modules.attention import optimized_attention
 
 class SpatialNorm(nn.Module):
     def __init__(
@@ -28,6 +28,259 @@ class SpatialNorm(nn.Module):
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f
 
+# Partial implementation of diffusers' Attention, adapted for ComfyUI.
+class Attention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        kv_heads: Optional[int] = None,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        norm_num_groups: Optional[int] = None,
+        spatial_norm_dim: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        out_dim: int = None,
+        out_context_dim: int = None,
+        context_pre_only=None,
+        pre_only=False,
+        is_causal: bool = False,
+    ):
+        super().__init__()
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.is_cross_attention = cross_attention_dim is not None
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.is_causal = is_causal
+
+        # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+        else:
+            self.group_norm = None
+
+        if spatial_norm_dim is not None:
+            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+        else:
+            self.spatial_norm = None
+
+        self.norm_q = None
+        self.norm_k = None
+        self.norm_cross = None
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        self.added_proj_bias = added_proj_bias
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            if self.context_pre_only is not None:
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        else:
+            self.add_q_proj = None
+            self.add_k_proj = None
+            self.add_v_proj = None
+
+        if not self.pre_only:
+            self.to_out = nn.ModuleList([])
+            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(nn.Dropout(dropout))
+        else:
+            self.to_out = None
+
+        if self.context_pre_only is not None and not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+        else:
+            self.to_add_out = None
+
+        self.norm_added_q = None
+        self.norm_added_k = None
+
+    # Defined as forward (rather than __call__) so nn.Module's call machinery applies.
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        if self.spatial_norm is not None:
+            hidden_states = self.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        # NOTE: the mask and cross-norm helpers below are not ported in this partial
+        # implementation; VAE usage passes neither an attention_mask nor norm_cross.
+        if attention_mask is not None:
+            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])
+
+        if self.group_norm is not None:
+            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = self.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif self.norm_cross:
+            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.heads
+
+        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+        if self.norm_q is not None:
+            query = self.norm_q(query)
+        if self.norm_k is not None:
+            key = self.norm_k(key)
+
+        hidden_states = optimized_attention(query, key, value, heads=self.heads, mask=attention_mask, skip_reshape=True, skip_output_reshape=True)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if self.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / self.rescale_output_factor
+
+        return hidden_states
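For orientation, here is how this trimmed-down Attention can be driven as a VAE-style mid-block self-attention; the channel sizes and flags below are assumptions for the sketch, not values taken from the patch.

```python
import torch

# Hypothetical mid-block self-attention over a feature map, mirroring how
# diffusers wires Attention inside a VAE (channel sizes are assumptions).
attn = Attention(
    query_dim=512,
    heads=1,
    dim_head=512,
    norm_num_groups=32,
    residual_connection=True,
    bias=True,
    out_bias=True,
)

feat = torch.randn(1, 512, 32, 32)  # (b, c, h, w)
out = attn(feat)                    # 4D in, 4D out: tokens are the h*w positions
assert out.shape == feat.shape
```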
+
+
+def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
+    """
+    Inflate a 2D convolution weight tensor into a 3D one.
+    Parameters:
+        weight_2d: The weight tensor of the 2D conv to inflate from.
+        weight_3d: The weight tensor of the 3D conv to initialize.
+        inflation_mode: Either "tail" or "replicate".
+    """
+    assert inflation_mode in ["tail", "replicate"]
+    assert weight_3d.shape[:2] == weight_2d.shape[:2]
+    with torch.no_grad():
+        if inflation_mode == "replicate":
+            depth = weight_3d.size(2)
+            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
+        else:
+            weight_3d.fill_(0.0)
+            weight_3d[:, :, -1].copy_(weight_2d)
+    return weight_3d
+
+
+def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
+    """
+    Inflate a 2D convolution bias tensor into a 3D one.
+    Parameters:
+        bias_2d: The bias tensor of the 2D conv to inflate from.
+        bias_3d: The bias tensor of the 3D conv to initialize.
+        inflation_mode: Placeholder to align with `inflate_weight`.
+    """
+    assert bias_3d.shape == bias_2d.shape
+    with torch.no_grad():
+        bias_3d.copy_(bias_2d)
+    return bias_3d
+
+
+def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
+    """
+    Inflate the 2D convolution parameters in a state dict to 3D.
+    """
+    weight_name = prefix + "weight"
+    bias_name = prefix + "bias"
+    if weight_name in state_dict:
+        weight_2d = state_dict[weight_name]
+        if weight_2d.dim() == 4:
+            # 2D conv weights are 4D tensors (out_channels, in_channels, h, w).
+            weight_3d = inflate_weight_fn(
+                weight_2d=weight_2d,
+                weight_3d=layer.weight,
+                inflation_mode=layer.inflation_mode,
+            )
+            state_dict[weight_name] = weight_3d
+        else:
+            # Already a 3D state dict; skip inflation for both weight and bias.
+            return state_dict
+    if bias_name in state_dict:
+        bias_2d = state_dict[bias_name]
+        if bias_2d.dim() == 1:
+            # Conv biases are 1D tensors (out_channels,).
+            bias_3d = inflate_bias_fn(
+                bias_2d=bias_2d,
+                bias_3d=layer.bias,
+                inflation_mode=layer.inflation_mode,
+            )
+            state_dict[bias_name] = bias_3d
+    return state_dict
+
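A sanity sketch of the inflation semantics (not from the patch itself): with "replicate", the 2D kernel is copied to every temporal tap and divided by depth, so a clip of identical frames reproduces the 2D conv's output on interior frames; with "tail", only the last temporal tap is non-zero, which preserves causal behavior.

```python
import torch
import torch.nn as nn

# Minimal check of the "replicate" property on toy convs.
conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv3d = nn.Conv3d(3, 8, kernel_size=(3, 3, 3), padding=(1, 1, 1))

inflate_weight(conv2d.weight, conv3d.weight, inflation_mode="replicate")
inflate_bias(conv2d.bias, conv3d.bias, inflation_mode="replicate")

frame = torch.randn(1, 3, 16, 16)
clip = frame.unsqueeze(2).repeat(1, 1, 5, 1, 1)  # 5 identical frames

out2d = conv2d(frame)
out3d = conv3d(clip)
# Interior frames see a full temporal window of identical inputs, so the
# averaged 3D kernel reproduces the 2D result there (edges see zero padding).
torch.testing.assert_close(out3d[:, :, 2], out2d, atol=1e-5, rtol=1e-4)
```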
 
 def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
     input_dtype = x.dtype
     if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)):
@@ -131,15 +384,14 @@ class InflatedCausalConv3d(nn.Conv3d):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # wirdly inflation_mode is pad, which would cause an assert error
-        #if self.inflation_mode != "none":
-        #    state_dict = modify_state_dict(
-        #        self,
-        #        state_dict,
-        #        prefix,
-        #        inflate_weight_fn=inflate_weight,
-        #        inflate_bias_fn=inflate_bias,
-        #    )
+        if self.inflation_mode != "none":
+            state_dict = modify_state_dict(
+                self,
+                state_dict,
+                prefix,
+                inflate_weight_fn=inflate_weight,
+                inflate_bias_fn=inflate_bias,
+            )
         super()._load_from_state_dict(
             state_dict,
             prefix,
@@ -287,7 +539,10 @@ class Downsample3D(nn.Module):
         spatial_down: bool = False,
         temporal_down: bool = False,
         name: str = "conv",
+        kernel_size=3,
+        use_conv: bool = False,
         padding = 1,
+        bias=True,
         **kwargs,
     ):
         super().__init__()
@@ -295,7 +550,6 @@ class Downsample3D(nn.Module):
         self.name = name
         self.channels = channels
         self.out_channels = out_channels or channels
-        conv = self.conv
         self.temporal_down = temporal_down
         self.spatial_down = spatial_down
 
@@ -305,9 +559,7 @@ class Downsample3D(nn.Module):
         self.temporal_kernel = 3 if temporal_down else 1
         self.spatial_kernel = 3 if spatial_down else 1
 
-        if type(conv) in [nn.Conv2d]:
-            # Note: lora_layer is not passed into constructor in the original implementation.
-            # So we make a simplification.
+        if use_conv:
             conv = InflatedCausalConv3d(
                 self.channels,
                 self.out_channels,
                 kernel_size=(
                     self.temporal_kernel,
                     self.spatial_kernel,
                     self.spatial_kernel,
                 ),
                 inflation_mode=inflation_mode,
             )
-        elif type(conv) is nn.AvgPool2d:
+        else:
             assert self.channels == self.out_channels
             conv = nn.AvgPool3d(
                 kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                 stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
             )
-        else:
-            raise NotImplementedError
+
+        self.conv = conv
 
-        if self.name == "conv":
-            self.Conv2d_0 = conv
-            self.conv = conv
-        else:
-            self.conv = conv
 
     def forward(
         self,
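The pooling branch in isolation, assuming the ratio attributes are 2 on downsampled axes and 1 elsewhere (they are set outside the visible hunk, so treat these values as an assumption):

```python
import torch
import torch.nn as nn

# Spatial-only downsampling: time axis kept, height/width halved.
pool = nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
x = torch.randn(1, 256, 9, 64, 64)  # (b, c, t, h, w)
y = pool(x)
print(y.shape)  # torch.Size([1, 256, 9, 32, 32])
```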
@@ -386,6 +633,9 @@ class ResnetBlock3D(nn.Module):
         super().__init__()
         self.up = up
         self.down = down
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.use_in_shortcut = use_in_shortcut
         self.output_scale_factor = output_scale_factor
         self.skip_time_act = skip_time_act
@@ -394,6 +644,12 @@ class ResnetBlock3D(nn.Module):
             self.time_emb_proj = nn.Linear(temb_channels, out_channels)
         else:
             self.time_emb_proj = None
+        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        if groups_out is None:
+            groups_out = groups
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
+        self.dropout = torch.nn.Dropout(dropout)
         self.conv1 = InflatedCausalConv3d(
             self.in_channels,
             self.out_channels,
@@ -405,7 +661,7 @@ class ResnetBlock3D(nn.Module):
 
         self.conv2 = InflatedCausalConv3d(
             self.out_channels,
-            self.conv2.out_channels,
+            conv_2d_out_channels,
             kernel_size=3,
             stride=1,
             padding=1,
@@ -431,11 +687,11 @@ class ResnetBlock3D(nn.Module):
         if self.use_in_shortcut:
             self.conv_shortcut = InflatedCausalConv3d(
                 self.in_channels,
-                self.conv_shortcut.out_channels,
+                conv_2d_out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
-                bias=(self.conv_shortcut.bias is not None),
+                bias=True,
                 inflation_mode=inflation_mode,
             )
 
@@ -534,7 +790,6 @@ class DownEncoderBlock3D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
                 [
-                    # [Override] Replace module.
                     Downsample3D(
                         out_channels,
                         use_conv=True,
@@ -1049,8 +1304,6 @@ class VideoAutoencoderKL(nn.Module):
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock3D",),
-        up_block_types: Tuple[str] = ("UpDecoderBlock3D",),
         layers_per_block: int = 2,
         act_fn: str = "silu",
         latent_channels: int = 16,
@@ -1059,7 +1312,7 @@ class VideoAutoencoderKL(nn.Module):
         temporal_scale_num: int = 2,
         slicing_up_num: int = 0,
         gradient_checkpoint: bool = False,
-        inflation_mode = "tail",
+        inflation_mode = "pad",
         time_receptive_field: _receptive_field_t = "full",
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
@@ -1068,6 +1321,8 @@ class VideoAutoencoderKL(nn.Module):
     ):
         extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
         block_out_channels = (128, 256, 512, 512)
+        down_block_types = ("DownEncoderBlock3D",) * 4
+        up_block_types = ("UpDecoderBlock3D",) * 4
         super().__init__()
 
         # pass init params to Encoder
@@ -1257,4 +1512,4 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         #set_norm_limit(norm_max_mem)
         for m in self.modules():
             if isinstance(m, InflatedCausalConv3d):
-                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
\ No newline at end of file
+                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 60bd551dd..9d4e8bf34 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -1,6 +1,5 @@
-
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, io, ui
+from comfy_api.latest import ComfyExtension, io
 import torch
 import math
 from einops import rearrange
@@ -9,7 +8,51 @@
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
 
+def expand_dims(tensor, ndim):
+    # Right-pad the shape with singleton dims so it broadcasts against ndim-D tensors.
+    shape = tensor.shape + (1,) * (ndim - tensor.ndim)
+    return tensor.reshape(shape)
 
+def get_conditions(latent, latent_blur):
+    # Concatenate the blurred latent with a mask channel of ones: (t, h, w, c + 1).
+    t, h, w, c = latent.shape
+    cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
+    cond[..., :-1] = latent_blur
+    cond[..., -1:] = 1.0
+    return cond
+
+def timestep_transform(timesteps, latents_shapes):
+    # VAE downsample factors: vt temporal, vs spatial.
+    vt = 4
+    vs = 8
+    frames = (latents_shapes[:, 0] - 1) * vt + 1
+    heights = latents_shapes[:, 1] * vs
+    widths = latents_shapes[:, 2] * vs
+
+    # Compute shift factor.
+    def get_lin_function(x1, y1, x2, y2):
+        m = (y2 - y1) / (x2 - x1)
+        b = y1 - m * x1
+        return lambda x: m * x + b
+
+    img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
+    vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
+    shift = torch.where(
+        frames > 1,
+        vid_shift_fn(heights * widths * frames),
+        img_shift_fn(heights * widths),
+    )
+
+    # Shift timesteps.
+    T = 1000.0
+    timesteps = timesteps / T
+    timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
+    timesteps = timesteps * T
+    return timesteps
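A worked example of the shift (illustrative shapes): the anchor points above give shift = 1.0 for a 256x256 image and shift = 5.0 for a 145-frame 1280x720 clip, so the same raw timestep is pushed much closer to the noise end for large videos.

```python
import torch

t = torch.tensor([500.0])

image_shape = torch.tensor([[1, 32, 32]])    # 1 latent frame, 32x32 latent -> 256x256 px
video_shape = torch.tensor([[37, 90, 160]])  # 145 frames at 720x1280 px after decoding

print(timestep_transform(t, image_shape))  # shift = 1.0 -> stays at 500
print(timestep_transform(t, video_shape))  # shift = 5.0 -> 5*0.5/(1+4*0.5)*1000 = ~833
```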
+
+def inter(x_0, x_T, t):
+    # Linear interpolation between data x_0 and noise x_T at normalized time t/T.
+    t = expand_dims(t, x_0.ndim)
+    T = 1000.0
+    B = lambda t: t / T
+    A = lambda t: 1 - (t / T)
+    return A(t) * x_0 + B(t) * x_T
 
 def area_resize(image, max_area):
     height, width = image.shape[-2:]
@@ -80,7 +123,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = normalize(images)
         images = rearrange(images, "t c h w -> c t h w")
         images = cut_videos(images)
-        return
+        return io.NodeOutput(images)
 
 class SeedVR2Conditioning(io.ComfyNode):
     @classmethod
@@ -93,16 +136,38 @@ class SeedVR2Conditioning(io.ComfyNode):
             io.Conditioning.Input("text_negative_conditioning"),
             io.Conditioning.Input("vae_conditioning")
         ],
-        outputs=[io.Conditioning.Output("positive"), io.Conditioning.Output("negative")],
+        outputs=[io.Conditioning.Output(display_name="positive"),
+                 io.Conditioning.Output(display_name="negative"),
+                 io.Latent.Output(display_name="latent")],
     )
 
     @classmethod
     def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
-        # TODO
+        # TODO: should do the flattening logic as with the original code
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]
+        # vae_conditioning is used as a single latent tensor below, so sample noise directly.
+        noises = torch.randn_like(vae_conditioning)
+        aug_noises = torch.randn_like(vae_conditioning)
+
+        cond_noise_scale = 0.0
+        t = torch.tensor([1000.0]) * cond_noise_scale
+        shape = torch.tensor(vae_conditioning.shape[1:])[None]
+        t = timestep_transform(t, shape)
+        cond = inter(vae_conditioning, aug_noises, t)
+        condition = get_conditions(noises, cond)
+
+        # TODO / FIXME
+        pos_cond = torch.cat([condition, pos_cond], dim=0)
+        neg_cond = torch.cat([condition, neg_cond], dim=0)
+
+        positive = [[pos_cond, {}]]
+        negative = [[neg_cond, {}]]
+
+        # Latent outputs are dicts carrying a "samples" tensor.
+        return io.NodeOutput(positive, negative, {"samples": noises})
 
 class SeedVRExtension(ComfyExtension):
     @override
@@ -113,4 +178,4 @@ class SeedVRExtension(ComfyExtension):
         ]
 
 async def comfy_entrypoint() -> SeedVRExtension:
-    return SeedVRExtension()
\ No newline at end of file
+    return SeedVRExtension()
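A toy run of the conditioning helpers, assuming a (t, h, w, c) latent layout; with cond_noise_scale = 0.0 the interpolation is the identity and get_conditions simply appends a ones mask channel.

```python
import torch

latent = torch.randn(5, 16, 16, 16)  # (t, h, w, c), toy sizes
noises = torch.randn_like(latent)
aug_noises = torch.randn_like(latent)

t = torch.tensor([1000.0]) * 0.0      # cond_noise_scale = 0.0
cond = inter(latent, aug_noises, t)   # (1 - t/T) * latent + (t/T) * noise
assert torch.allclose(cond, latent)   # t = 0 -> identity

condition = get_conditions(noises, cond)
print(condition.shape)                # torch.Size([5, 16, 16, 17]): cond + ones mask
```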