diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 7444e2823..cbf1383d3 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1,8 +1,10 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, einsum, repeat
+from einops import rearrange, repeat
+import comfy.model_management
 from torch import nn
+import torch.nn.utils.rnn as rnn_utils
 import torch.nn.functional as F
 from math import ceil, pi
 import torch
@@ -559,6 +561,8 @@ class MMModule(nn.Module):
         torch.FloatTensor,
     ]:
         vid_module = self.vid if not self.shared_weights else self.all
+        device = comfy.model_management.get_torch_device()
+        vid = vid.to(device)
         vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
         if not self.vid_only:
             txt_module = self.txt if not self.shared_weights else self.all
@@ -616,58 +620,8 @@ class NaMMAttention(nn.Module):

         self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)

-    def forward(
-        self,
-        vid: torch.FloatTensor,  # l c
-        txt: torch.FloatTensor,  # l c
-        vid_shape: torch.LongTensor,  # b 3
-        txt_shape: torch.LongTensor,  # b 1
-        cache: Cache,
-    ) -> Tuple[
-        torch.FloatTensor,
-        torch.FloatTensor,
-    ]:
-
-        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
-        vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
-        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
-
-        vid_q, vid_k, vid_v = vid_qkv.unbind(1)
-        txt_q, txt_k, txt_v = txt_qkv.unbind(1)
-
-        vid_q, txt_q = self.norm_q(vid_q, txt_q)
-        vid_k, txt_k = self.norm_k(vid_k, txt_k)
-
-        if self.rope:
-            if self.rope.mm:
-                vid_q, vid_k, txt_q, txt_k = self.rope(
-                    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
-                )
-            else:
-                vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)
-
-        vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
-        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
-        all_len = cache("all_len", lambda: vid_len + txt_len)
-
-        b = len(vid_len)
-        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
-        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)]
-
-        q = torch.cat([vq, tq], dim=1)
-        k = torch.cat([vk, tk], dim=1)
-        v = torch.cat([vv, tv], dim=1)
-
-        _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
-
-        attn = optimized_attention(q, k, v, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
-        attn = attn.flatten(0, 1) # to continue working with the rest of the code
-
-        attn = rearrange(attn, "l h d -> l (h d)")
-        vid_out, txt_out = unconcat(attn)
-
-        vid_out, txt_out = self.proj_out(vid_out, txt_out)
-        return vid_out, txt_out
+    def forward(self):
+        pass

 def window(
     hid: torch.FloatTensor,  # (L c)
@@ -783,23 +737,78 @@ class NaSwinAttention(NaMMAttention):
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)

         # TODO: continue testing

-        b = len(vid_len_win)
-        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
-        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)]
+        v_lens = vid_len_win.cpu().tolist()
+        t_lens_batch = txt_len.cpu().tolist()
+        win_counts = window_count.cpu().tolist()

-        q = torch.cat([vq, tq], dim=1)
-        k = torch.cat([vk, tk], dim=1)
-        v = torch.cat([vv, tv], dim=1)
-        out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
-        out = out.flatten(0, 1)
+        vq_l = torch.split(vid_q, v_lens)
+        vk_l = torch.split(vid_k, v_lens)
+        vv_l = torch.split(vid_v, v_lens)
+
+        tv_batch = torch.split(txt_v, t_lens_batch)
+        tv_l = []
+        for i, count in enumerate(win_counts):
+            tv_l.extend([tv_batch[i]] * count)
+
+        current_txt_len = txt_q.shape[0]
+        expected_batch_len = sum(t_lens_batch)
+
+        if current_txt_len != expected_batch_len:
+            t_lens_win = txt_len_win.cpu().tolist()
+
+            tq_l = torch.split(txt_q, t_lens_win)
+            tk_l = torch.split(txt_k, t_lens_win)
+        else:
+            tq_batch = torch.split(txt_q, t_lens_batch)
+            tk_batch = torch.split(txt_k, t_lens_batch)
+
+            tq_l = []
+            tk_l = []
+            for i, count in enumerate(win_counts):
+                tq_l.extend([tq_batch[i]] * count)
+                tk_l.extend([tk_batch[i]] * count)
+
+        q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
+        k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
+        v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
+
+        q = rnn_utils.pad_sequence(q_list, batch_first=True)
+        k = rnn_utils.pad_sequence(k_list, batch_first=True)
+        v = rnn_utils.pad_sequence(v_list, batch_first=True)
+
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        B, Heads, Max_L, _ = q.shape
+        combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
+
+        attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
+        idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
+        len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
+
+        padding_mask = idx >= len_tensor
+        attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
+
+        out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
+
+        out = out.transpose(1, 2)
+
+        out_flat_list = []
+        for i, length in enumerate(combined_lens):
+            out_flat_list.append(out[i, :length])
+
+        out = torch.cat(out_flat_list, dim=0)

-        # text pooling
         vid_out, txt_out = unconcat_win(out)
         vid_out = rearrange(vid_out, "l h d -> l (h d)")
         txt_out = rearrange(txt_out, "l h d -> l (h d)")

         vid_out = window_reverse(vid_out)
+        device = comfy.model_management.get_torch_device()
+        vid_out, txt_out = vid_out.to(device), txt_out.to(device)
+        self.proj_out = self.proj_out.to(device)
         vid_out, txt_out = self.proj_out(vid_out, txt_out)
         return vid_out, txt_out

@@ -837,6 +846,8 @@ class SwiGLUMLP(nn.Module):
         self.proj_in = nn.Linear(dim, hidden_dim, bias=False)

     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        x = x.to(next(self.proj_in.parameters()).device)
+        self.proj_out = self.proj_out.to(x.device)
         x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
         return x

@@ -928,6 +939,7 @@ class NaMMSRTransformerBlock(nn.Module):
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
         vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
+        txt = txt.to(txt_attn.device)
         vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)

         vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
@@ -967,12 +979,11 @@ class NaPatchOut(PatchOut):
         vid: torch.FloatTensor,  # l c
         vid_shape: torch.LongTensor,
         cache: Cache = Cache(disable=True),  # for test
+        vid_shape_before_patchify = None
     ) -> Tuple[
         torch.FloatTensor,
         torch.LongTensor,
     ]:
-        cache = cache.namespace("patch")
-        vid_shape_before_patchify = cache.get("vid_shape_before_patchify")

         t, h, w = self.patch_size
         vid = self.proj(vid)
@@ -1074,6 +1085,16 @@ class AdaSingle(nn.Module):
         emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
         emb = expand_dims(emb, 1, hid.ndim + 1)

+        if hid_len is not None:
+            slice_inputs = lambda x, dim: x
+            emb = cache(
+                f"emb_repeat_{idx}_{branch_tag}",
+                lambda: slice_inputs(
+                    torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
+                    dim=0,
+                ),
+            )
+
         shiftA, scaleA, gateA = emb.unbind(-1)
         shiftB, scaleB, gateB = (
             getattr(self, f"{layer}_shift", None),
@@ -1214,8 +1235,8 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
         super().__init__()
-        self.register_parameter("positive_conditioning", torch.empty((58, 5120)))
-        self.register_parameter("negative_conditioning", torch.empty((64, 5120)))
+        self.register_buffer("positive_conditioning", torch.empty((58, 5120)))
+        self.register_buffer("negative_conditioning", torch.empty((64, 5120)))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
@@ -1306,13 +1327,14 @@ class NaDiT(nn.Module):
         x,
         timestep,
         context,  # l c
-        disable_cache: bool = True,  # for test # TODO ?
+        disable_cache: bool = False,  # for test # TODO ? // gives an error when set to True
         **kwargs
     ):
         transformer_options = kwargs.get("transformer_options", {})
         conditions = kwargs.get("condition")

-        pos_cond, neg_cond = context.chunk(2, dim=0)
+        pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
+        pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
         pos_cond, txt_shape = flatten([pos_cond])
         neg_cond, _ = flatten([neg_cond])
         txt = torch.cat([pos_cond, neg_cond], dim = 0)
@@ -1331,6 +1353,7 @@ class NaDiT(nn.Module):
         vid = vid.to(device).to(dtype)

         txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
+        vid_shape_before_patchify = vid_shape
         vid, vid_shape = self.vid_in(vid, vid_shape)
         emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

@@ -1358,6 +1381,6 @@ class NaDiT(nn.Module):
                 branch_tag="vid",
             )

-        vid, vid_shape = self.vid_out(vid, vid_shape, cache)
+        vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
         vid = unflatten(vid, vid_shape)
-        return vid
+        return vid[0]
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 1086f9adc..6c58f044b 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -6,9 +6,31 @@
 import torch.nn.functional as F
 from einops import rearrange
 from comfy.ldm.seedvr.model import safe_pad_operation
-from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
 from comfy.ldm.modules.attention import optimized_attention

+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        sample = torch.randn(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        x = self.mean + self.std * sample
+        return x
+
 class SpatialNorm(nn.Module):
     def __init__(
         self,
@@ -453,7 +475,7 @@ class Upsample3D(nn.Module):
         else:
             self.Conv2d_0 = conv

-        self.norm = False
+        self.norm = None

     def forward(
         self,
@@ -1255,6 +1277,7 @@ class Decoder3D(nn.Module):
         latent_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:

+        sample = sample.to(next(self.parameters()).device)
         sample = self.conv_in(sample)
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

@@ -1397,10 +1420,10 @@ class VideoAutoencoderKL(nn.Module):
     def _decode(
         self, z: torch.Tensor
     ) -> torch.Tensor:
-        _z = z.to(self.device)
+        latent = z.to(self.device)
         if self.post_quant_conv is not None:
-            _z = self.post_quant_conv(_z)
-        output = self.decoder(_z)
+            latent = self.post_quant_conv(latent)
+        output = self.decoder(latent)
         return output.to(z.device)

     def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -1473,9 +1496,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p

     def decode(self, z: torch.FloatTensor):
+        latent = z.unsqueeze(0)
+        scale = 0.9152
+        shift = 0
+        latent = latent / scale + shift
+        latent = rearrange(latent, "b ... c -> b c ...")
+        latent = latent.squeeze(2)
         if z.ndim == 4:
             z = z.unsqueeze(2)
-        x = super().decode(z).sample.squeeze(2)
+        x = super().decode(latent).squeeze(2)
         return x

     def preprocess(self, x: torch.Tensor):
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 8a108f37e..eebcb7dc0 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):

 def get_conditions(latent, latent_blur):
     t, h, w, c = latent.shape
-    cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
-    #cond[:, ..., :-1] = latent_blur[:]
-    #cond[:, ..., -1:] = 1.0
+    cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
+    cond[:, ..., :-1] = latent_blur[:]
+    cond[:, ..., -1:] = 1.0
     return cond

 def timestep_transform(timesteps, latents_shapes):
@@ -117,6 +117,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
         device = vae.patcher.load_device
+        offload_device = vae.patcher.offload_device
         vae = vae.first_stage_model
         scale = 0.9152; shift = 0

@@ -144,6 +145,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae = vae.to(device)
         images = images.to(device)
         latent = vae.encode(images)[0]
+        vae = vae.to(offload_device)
         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")

@@ -196,8 +198,9 @@ class SeedVR2Conditioning(io.ComfyNode):
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

-        negative = [[neg_cond, {"condition": condition}]]
-        positive = [[pos_cond, {"condition": condition}]]
+        cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
+        negative = [[cond, {"condition": condition}]]
+        positive = [[cond, {"condition": condition}]]

         return io.NodeOutput(positive, negative, {"samples": noises})