Yousef Rafat 2025-12-12 20:51:40 +02:00
parent d629c8f910
commit 768c9cedf8
3 changed files with 136 additions and 81 deletions

View File

@@ -1,8 +1,10 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import einops
from einops import rearrange, einsum, repeat
from einops import rearrange, repeat
import comfy.model_management
from torch import nn
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F
from math import ceil, pi
import torch
@@ -559,6 +561,8 @@ class MMModule(nn.Module):
torch.FloatTensor,
]:
vid_module = self.vid if not self.shared_weights else self.all
device = comfy.model_management.get_torch_device()
vid = vid.to(device)
vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
if not self.vid_only:
txt_module = self.txt if not self.shared_weights else self.all
@@ -616,58 +620,8 @@ class NaMMAttention(nn.Module):
self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
def forward(
self,
vid: torch.FloatTensor, # l c
txt: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: torch.LongTensor, # b 1
cache: Cache,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:
vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
vid_q, vid_k, vid_v = vid_qkv.unbind(1)
txt_q, txt_k, txt_v = txt_qkv.unbind(1)
vid_q, txt_q = self.norm_q(vid_q, txt_q)
vid_k, txt_k = self.norm_k(vid_k, txt_k)
if self.rope:
if self.rope.mm:
vid_q, vid_k, txt_q, txt_k = self.rope(
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
)
else:
vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)
vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
all_len = cache("all_len", lambda: vid_len + txt_len)
b = len(vid_len)
vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
q = torch.cat([vq, tq], dim=1)
k = torch.cat([vk, tk], dim=1)
v = torch.cat([vv, tv], dim=1)
_, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
attn = optimized_attention(q, k, v, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
attn = attn.flatten(0, 1) # to continue working with the rest of the code
attn = rearrange(attn, "l h d -> l (h d)")
vid_out, txt_out = unconcat(attn)
vid_out, txt_out = self.proj_out(vid_out, txt_out)
return vid_out, txt_out
def forward(self):
pass
def window(
hid: torch.FloatTensor, # (L c)
@@ -783,23 +737,78 @@ class NaSwinAttention(NaMMAttention):
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
# TODO: continue testing
b = len(vid_len_win)
vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
v_lens = vid_len_win.cpu().tolist()
t_lens_batch = txt_len.cpu().tolist()
win_counts = window_count.cpu().tolist()
q = torch.cat([vq, tq], dim=1)
k = torch.cat([vk, tk], dim=1)
v = torch.cat([vv, tv], dim=1)
out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
out = out.flatten(0, 1)
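# split the flattened video q/k/v into per-window chunks along the token dimension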
vq_l = torch.split(vid_q, v_lens)
vk_l = torch.split(vid_k, v_lens)
vv_l = torch.split(vid_v, v_lens)
tv_batch = torch.split(txt_v, t_lens_batch)
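# text value tokens are stored per batch sample; repeat each sample's text once for every window it owns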
tv_l = []
for i, count in enumerate(win_counts):
tv_l.extend([tv_batch[i]] * count)
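# text q/k may already be laid out per window or still per batch sample; compare token counts to tell which, then expand to per-window lists either way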
current_txt_len = txt_q.shape[0]
expected_batch_len = sum(t_lens_batch)
if current_txt_len != expected_batch_len:
t_lens_win = txt_len_win.cpu().tolist()
tq_l = torch.split(txt_q, t_lens_win)
tk_l = torch.split(txt_k, t_lens_win)
else:
tq_batch = torch.split(txt_q, t_lens_batch)
tk_batch = torch.split(txt_k, t_lens_batch)
tq_l = []
tk_l = []
for i, count in enumerate(win_counts):
tq_l.extend([tq_batch[i]] * count)
tk_l.extend([tk_batch[i]] * count)
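# concatenate video and text tokens for each window, then right-pad all windows to a common length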
q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
q = rnn_utils.pad_sequence(q_list, batch_first=True)
k = rnn_utils.pad_sequence(k_list, batch_first=True)
v = rnn_utils.pad_sequence(v_list, batch_first=True)
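# optimized_attention is called with skip_reshape=True, so move heads in front of the sequence dim: (windows, heads, max_len, head_dim)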
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
B, Heads, Max_L, _ = q.shape
combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
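# additive mask: key positions past each window's true length get -inf so the padding contributes nothing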
attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
padding_mask = idx >= len_tensor
attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
out = out.transpose(1, 2)
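# drop the padding from each window and re-flatten back to a single (tokens, heads, head_dim) sequence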
out_flat_list = []
for i, length in enumerate(combined_lens):
out_flat_list.append(out[i, :length])
out = torch.cat(out_flat_list, dim=0)
# text pooling
vid_out, txt_out = unconcat_win(out)
vid_out = rearrange(vid_out, "l h d -> l (h d)")
txt_out = rearrange(txt_out, "l h d -> l (h d)")
vid_out = window_reverse(vid_out)
device = comfy.model_management.get_torch_device()
vid_out, txt_out = vid_out.to(device), txt_out.to(device)
self.proj_out = self.proj_out.to(device)
vid_out, txt_out = self.proj_out(vid_out, txt_out)
return vid_out, txt_out
@@ -837,6 +846,8 @@ class SwiGLUMLP(nn.Module):
self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
x = x.to(next(self.proj_in.parameters()).device)
self.proj_out = self.proj_out.to(x.device)
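# SwiGLU: gate proj_in(x) with SiLU(proj_in_gate(x)) before the output projection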
x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
return x
@@ -928,6 +939,7 @@ class NaMMSRTransformerBlock(nn.Module):
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
txt = txt.to(txt_attn.device)
vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
@@ -967,12 +979,11 @@ class NaPatchOut(PatchOut):
vid: torch.FloatTensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), # for test
vid_shape_before_patchify = None
) -> Tuple[
torch.FloatTensor,
torch.LongTensor,
]:
cache = cache.namespace("patch")
vid_shape_before_patchify = cache.get("vid_shape_before_patchify")
t, h, w = self.patch_size
vid = self.proj(vid)
@@ -1074,6 +1085,16 @@ class AdaSingle(nn.Module):
emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
emb = expand_dims(emb, 1, hid.ndim + 1)
if hid_len is not None:
slice_inputs = lambda x, dim: x
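# slice_inputs is an identity placeholder here; the cached lambda repeats each sample's modulation embedding across that sample's tokens to match the flattened sequence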
emb = cache(
f"emb_repeat_{idx}_{branch_tag}",
lambda: slice_inputs(
torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
dim=0,
),
)
shiftA, scaleA, gateA = emb.unbind(-1)
shiftB, scaleB, gateB = (
getattr(self, f"{layer}_shift", None),
@@ -1214,8 +1235,8 @@ class NaDiT(nn.Module):
elif len(block_type) != num_layers:
raise ValueError("The length of the ``block_type`` list should equal ``num_layers``.")
super().__init__()
self.register_parameter("positive_conditioning", torch.empty((58, 5120)))
self.register_parameter("negative_conditioning", torch.empty((64, 5120)))
self.register_buffer("positive_conditioning", torch.empty((58, 5120)))
self.register_buffer("negative_conditioning", torch.empty((64, 5120)))
self.vid_in = NaPatchIn(
in_channels=vid_in_channels,
patch_size=patch_size,
@@ -1306,13 +1327,14 @@ class NaDiT(nn.Module):
x,
timestep,
context, # l c
disable_cache: bool = True, # for test # TODO ?
disable_cache: bool = False, # for test; TODO: gives an error when set to True
**kwargs
):
transformer_options = kwargs.get("transformer_options", {})
conditions = kwargs.get("condition")
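# context holds the positive and negative text embeddings stacked together; split them apart and flatten each into a variable-length token stream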
pos_cond, neg_cond = context.chunk(2, dim=0)
pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
pos_cond, txt_shape = flatten([pos_cond])
neg_cond, _ = flatten([neg_cond])
txt = torch.cat([pos_cond, neg_cond], dim = 0)
@@ -1331,6 +1353,7 @@ class NaDiT(nn.Module):
vid = vid.to(device).to(dtype)
txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
vid_shape_before_patchify = vid_shape
vid, vid_shape = self.vid_in(vid, vid_shape)
emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
@@ -1358,6 +1381,6 @@ class NaDiT(nn.Module):
branch_tag="vid",
)
vid, vid_shape = self.vid_out(vid, vid_shape, cache)
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
vid = unflatten(vid, vid_shape)
return vid
return vid[0]

View File

@@ -6,9 +6,31 @@ import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
from comfy.ldm.modules.attention import optimized_attention
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
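# reparameterized sample: mean + std * eps with eps drawn from a standard normal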
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
sample = torch.randn(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
class SpatialNorm(nn.Module):
def __init__(
self,
@@ -453,7 +475,7 @@ class Upsample3D(nn.Module):
else:
self.Conv2d_0 = conv
self.norm = False
self.norm = None
def forward(
self,
@@ -1255,6 +1277,7 @@ class Decoder3D(nn.Module):
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
@@ -1397,10 +1420,10 @@ class VideoAutoencoderKL(nn.Module):
def _decode(
self, z: torch.Tensor
) -> torch.Tensor:
_z = z.to(self.device)
latent = z.to(self.device)
if self.post_quant_conv is not None:
_z = self.post_quant_conv(_z)
output = self.decoder(_z)
latent = self.post_quant_conv(latent)
output = self.decoder(latent)
return output.to(z.device)
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -1473,9 +1496,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
return z, p
def decode(self, z: torch.FloatTensor):
latent = z.unsqueeze(0)
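# rescale the latent back (the same scale/shift constants are used in the SeedVR2 input-processing node) before decoding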
scale = 0.9152
shift = 0
latent = latent / scale + shift
latent = rearrange(latent, "b ... c -> b c ...")
latent = latent.squeeze(2)
if z.ndim == 4:
z = z.unsqueeze(2)
x = super().decode(z).sample.squeeze(2)
x = super().decode(latent).squeeze(2)
return x
def preprocess(self, x: torch.Tensor):

View File

@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):
def get_conditions(latent, latent_blur):
t, h, w, c = latent.shape
cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
#cond[:, ..., :-1] = latent_blur[:]
#cond[:, ..., -1:] = 1.0
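# condition tensor: the blurred latent fills the first c channels and the last channel is an all-ones mask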
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
cond[:, ..., :-1] = latent_blur[:]
cond[:, ..., -1:] = 1.0
return cond
def timestep_transform(timesteps, latents_shapes):
@@ -117,6 +117,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
@classmethod
def execute(cls, images, vae, resolution_height, resolution_width):
device = vae.patcher.load_device
offload_device = vae.patcher.offload_device
vae = vae.first_stage_model
scale = 0.9152; shift = 0
@@ -144,6 +145,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
vae = vae.to(device)
images = images.to(device)
latent = vae.encode(images)[0]
vae = vae.to(offload_device)
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
latent = rearrange(latent, "b c ... -> b ... c")
@@ -196,8 +198,9 @@ class SeedVR2Conditioning(io.ComfyNode):
else:
pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
negative = [[neg_cond, {"condition": condition}]]
positive = [[pos_cond, {"condition": condition}]]
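# stack the positive and negative embeddings into one context tensor; NaDiT.forward splits them back apart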
cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
negative = [[cond, {"condition": condition}]]
positive = [[cond, {"condition": condition}]]
return io.NodeOutput(positive, negative, {"samples": noises})