Yousef Rafat 2025-12-12 20:51:40 +02:00
parent d629c8f910
commit 768c9cedf8
3 changed files with 136 additions and 81 deletions

View File

@@ -1,8 +1,10 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, einsum, repeat
+from einops import rearrange, repeat
+import comfy.model_management
 from torch import nn
+import torch.nn.utils.rnn as rnn_utils
 import torch.nn.functional as F
 from math import ceil, pi
 import torch
@@ -559,6 +561,8 @@ class MMModule(nn.Module):
         torch.FloatTensor,
     ]:
         vid_module = self.vid if not self.shared_weights else self.all
+        device = comfy.model_management.get_torch_device()
+        vid = vid.to(device)
         vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
         if not self.vid_only:
             txt_module = self.txt if not self.shared_weights else self.all
@@ -616,58 +620,8 @@ class NaMMAttention(nn.Module):
         self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)

-    def forward(
-        self,
-        vid: torch.FloatTensor,  # l c
-        txt: torch.FloatTensor,  # l c
-        vid_shape: torch.LongTensor,  # b 3
-        txt_shape: torch.LongTensor,  # b 1
-        cache: Cache,
-    ) -> Tuple[
-        torch.FloatTensor,
-        torch.FloatTensor,
-    ]:
-        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
-        vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
-        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
-        vid_q, vid_k, vid_v = vid_qkv.unbind(1)
-        txt_q, txt_k, txt_v = txt_qkv.unbind(1)
-        vid_q, txt_q = self.norm_q(vid_q, txt_q)
-        vid_k, txt_k = self.norm_k(vid_k, txt_k)
-        if self.rope:
-            if self.rope.mm:
-                vid_q, vid_k, txt_q, txt_k = self.rope(
-                    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
-                )
-            else:
-                vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)
-        vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
-        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
-        all_len = cache("all_len", lambda: vid_len + txt_len)
-        b = len(vid_len)
-        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
-        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)]
-        q = torch.cat([vq, tq], dim=1)
-        k = torch.cat([vk, tk], dim=1)
-        v = torch.cat([vv, tv], dim=1)
-        _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
-        attn = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
-        attn = attn.flatten(0, 1)  # to continue working with the rest of the code
-        attn = rearrange(attn, "l h d -> l (h d)")
-        vid_out, txt_out = unconcat(attn)
-        vid_out, txt_out = self.proj_out(vid_out, txt_out)
-        return vid_out, txt_out
+    def forward(self):
+        pass

 def window(
     hid: torch.FloatTensor,  # (L c)
@@ -783,23 +737,78 @@ class NaSwinAttention(NaMMAttention):
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)

         # TODO: continue testing
-        b = len(vid_len_win)
-        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
-        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)]
-        q = torch.cat([vq, tq], dim=1)
-        k = torch.cat([vk, tk], dim=1)
-        v = torch.cat([vv, tv], dim=1)
-        out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
-        out = out.flatten(0, 1)
+        v_lens = vid_len_win.cpu().tolist()
+        t_lens_batch = txt_len.cpu().tolist()
+        win_counts = window_count.cpu().tolist()
+        vq_l = torch.split(vid_q, v_lens)
+        vk_l = torch.split(vid_k, v_lens)
+        vv_l = torch.split(vid_v, v_lens)
+        tv_batch = torch.split(txt_v, t_lens_batch)
+        tv_l = []
+        for i, count in enumerate(win_counts):
+            tv_l.extend([tv_batch[i]] * count)
+        current_txt_len = txt_q.shape[0]
+        expected_batch_len = sum(t_lens_batch)
+        if current_txt_len != expected_batch_len:
+            t_lens_win = txt_len_win.cpu().tolist()
+            tq_l = torch.split(txt_q, t_lens_win)
+            tk_l = torch.split(txt_k, t_lens_win)
+        else:
+            tq_batch = torch.split(txt_q, t_lens_batch)
+            tk_batch = torch.split(txt_k, t_lens_batch)
+            tq_l = []
+            tk_l = []
+            for i, count in enumerate(win_counts):
+                tq_l.extend([tq_batch[i]] * count)
+                tk_l.extend([tk_batch[i]] * count)
+        q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
+        k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
+        v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
+        q = rnn_utils.pad_sequence(q_list, batch_first=True)
+        k = rnn_utils.pad_sequence(k_list, batch_first=True)
+        v = rnn_utils.pad_sequence(v_list, batch_first=True)
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        B, Heads, Max_L, _ = q.shape
+        combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
+        attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
+        idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
+        len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
+        padding_mask = idx >= len_tensor
+        attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
+        out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
+        out = out.transpose(1, 2)
+        out_flat_list = []
+        for i, length in enumerate(combined_lens):
+            out_flat_list.append(out[i, :length])
+        out = torch.cat(out_flat_list, dim=0)
+        # text pooling
         vid_out, txt_out = unconcat_win(out)
         vid_out = rearrange(vid_out, "l h d -> l (h d)")
         txt_out = rearrange(txt_out, "l h d -> l (h d)")
         vid_out = window_reverse(vid_out)
+        device = comfy.model_management.get_torch_device()
+        vid_out, txt_out = vid_out.to(device), txt_out.to(device)
+        self.proj_out = self.proj_out.to(device)
         vid_out, txt_out = self.proj_out(vid_out, txt_out)
         return vid_out, txt_out
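Reviewer note: the rewritten windowed path replaces the fixed-size reshape with per-window splitting, padding to a common length, and an additive key mask. A minimal, self-contained sketch of that padding-plus-mask pattern (toy shapes, plain `scaled_dot_product_attention` instead of the repo's `optimized_attention`):

```python
# Sketch only: variable-length windows padded with pad_sequence, padded keys
# masked with -inf so they contribute nothing to the softmax.
import torch
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

def padded_window_attention(seqs):
    # seqs: list of [L_i, heads, head_dim] tensors, one per window
    lens = [s.shape[0] for s in seqs]
    x = rnn_utils.pad_sequence(seqs, batch_first=True)       # [B, L_max, h, d]
    q = k = v = x.transpose(1, 2)                             # [B, h, L_max, d]
    B, _, L_max, _ = q.shape
    idx = torch.arange(L_max).unsqueeze(0).expand(B, L_max)
    pad = idx >= torch.tensor(lens).unsqueeze(1)              # True where padded
    mask = torch.zeros(B, 1, 1, L_max).masked_fill(pad[:, None, None], float("-inf"))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    out = out.transpose(1, 2)                                  # [B, L_max, h, d]
    # drop the padding again and re-pack into one flat sequence, as the diff does
    return torch.cat([out[i, :l] for i, l in enumerate(lens)], dim=0)

windows = [torch.randn(l, 4, 8) for l in (5, 3, 7)]
print(padded_window_attention(windows).shape)  # torch.Size([15, 4, 8])
```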
@@ -837,6 +846,8 @@ class SwiGLUMLP(nn.Module):
         self.proj_in = nn.Linear(dim, hidden_dim, bias=False)

     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        x = x.to(next(self.proj_in.parameters()).device)
+        self.proj_out = self.proj_out.to(x.device)
         x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
         return x
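For reference, the forward pass above is the standard SwiGLU formulation, out = W_out(SiLU(W_gate x) * W_in x). A small self-contained sketch with illustrative names:

```python
# Gate branch goes through SiLU and multiplies the value branch before the
# output projection; names here are illustrative, not the repo's.
import torch
from torch import nn
import torch.nn.functional as F

class TinySwiGLU(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.proj_in = nn.Linear(dim, hidden_dim, bias=False)       # value branch
        self.proj_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))

print(TinySwiGLU(64, 256)(torch.randn(2, 64)).shape)  # torch.Size([2, 64])
```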
@@ -928,6 +939,7 @@ class NaMMSRTransformerBlock(nn.Module):
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
         vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
+        txt = txt.to(txt_attn.device)
         vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)

         vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
@@ -967,12 +979,11 @@ class NaPatchOut(PatchOut):
         vid: torch.FloatTensor,  # l c
         vid_shape: torch.LongTensor,
         cache: Cache = Cache(disable=True),  # for test
+        vid_shape_before_patchify = None
     ) -> Tuple[
         torch.FloatTensor,
         torch.LongTensor,
     ]:
-        cache = cache.namespace("patch")
-        vid_shape_before_patchify = cache.get("vid_shape_before_patchify")
         t, h, w = self.patch_size

         vid = self.proj(vid)
@@ -1074,6 +1085,16 @@ class AdaSingle(nn.Module):
         emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
         emb = expand_dims(emb, 1, hid.ndim + 1)

+        if hid_len is not None:
+            slice_inputs = lambda x, dim: x
+            emb = cache(
+                f"emb_repeat_{idx}_{branch_tag}",
+                lambda: slice_inputs(
+                    torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
+                    dim=0,
+                ),
+            )
         shiftA, scaleA, gateA = emb.unbind(-1)
         shiftB, scaleB, gateB = (
             getattr(self, f"{layer}_shift", None),
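Reviewer note: the new branch repeats each batch element's modulation embedding once per token so it lines up with a packed, flattened sequence. A toy-shape sketch of that repeat-by-length pattern (not the repo's API):

```python
# Each per-sample embedding is repeated hid_len[i] times, then concatenated
# to match a packed sequence of 3 + 5 tokens.
import torch

emb = torch.randn(2, 8)         # one modulation vector per batch element
hid_len = torch.tensor([3, 5])  # tokens contributed by each element
packed = torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)])
print(packed.shape)             # torch.Size([8, 8])
```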
@@ -1214,8 +1235,8 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
         super().__init__()
-        self.register_parameter("positive_conditioning", torch.empty((58, 5120)))
-        self.register_parameter("negative_conditioning", torch.empty((64, 5120)))
+        self.register_buffer("positive_conditioning", torch.empty((58, 5120)))
+        self.register_buffer("negative_conditioning", torch.empty((64, 5120)))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
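Why the switch matters: `register_parameter` only accepts an `nn.Parameter`, while `register_buffer` takes a plain tensor, keeps it out of the optimizer, and still moves with `.to()` and is saved in `state_dict`. A standalone check:

```python
# register_buffer accepts a plain tensor; register_parameter rejects it.
import torch
from torch import nn

m = nn.Module()
m.register_buffer("positive_conditioning", torch.empty(58, 5120))  # works
try:
    m.register_parameter("negative_conditioning", torch.empty(64, 5120))
except TypeError as e:
    print("register_parameter rejects plain tensors:", e)
```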
@@ -1306,13 +1327,14 @@ class NaDiT(nn.Module):
         x,
         timestep,
         context,  # l c
-        disable_cache: bool = True,  # for test # TODO ?
+        disable_cache: bool = False,  # for test # TODO ? // gives an error when set to True
         **kwargs
     ):
         transformer_options = kwargs.get("transformer_options", {})
         conditions = kwargs.get("condition")
-        pos_cond, neg_cond = context.chunk(2, dim=0)
+        pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
+        pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
         pos_cond, txt_shape = flatten([pos_cond])
         neg_cond, _ = flatten([neg_cond])
         txt = torch.cat([pos_cond, neg_cond], dim = 0)
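Shape sketch (toy sizes) of how this unpacks the stacked conditioning built by the SeedVR2Conditioning node below, assuming the `[1, 2, L, C]` layout that node produces reaches `forward()` unchanged:

```python
# The node packs positive and negative prompts into one tensor; forward()
# recovers the two [L, C] sequences with squeeze + chunk.
import torch

context = torch.randn(1, 2, 64, 5120)          # [batch, pos/neg, L, C]
pos, neg = context.squeeze(0).chunk(2, dim=0)  # two [1, L, C] tensors
pos, neg = pos.squeeze(0), neg.squeeze(0)      # [L, C] each
print(pos.shape, neg.shape)                    # torch.Size([64, 5120]) twice
```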
@@ -1331,6 +1353,7 @@ class NaDiT(nn.Module):
         vid = vid.to(device).to(dtype)
         txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
+        vid_shape_before_patchify = vid_shape
         vid, vid_shape = self.vid_in(vid, vid_shape)
         emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
@@ -1358,6 +1381,6 @@ class NaDiT(nn.Module):
             branch_tag="vid",
         )
-        vid, vid_shape = self.vid_out(vid, vid_shape, cache)
+        vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
         vid = unflatten(vid, vid_shape)
-        return vid
+        return vid[0]

View File

@@ -6,9 +6,31 @@ import torch.nn.functional as F
 from einops import rearrange
 from comfy.ldm.seedvr.model import safe_pad_operation
-from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
 from comfy.ldm.modules.attention import optimized_attention

+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        sample = torch.randn(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        x = self.mean + self.std * sample
+        return x
+
 class SpatialNorm(nn.Module):
     def __init__(
         self,
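Usage sketch for the vendored DiagonalGaussianDistribution (toy shapes): the encoder output carries mean and log-variance stacked along dim=1, and sampling is the usual reparameterization mean + std * eps, mirrored here without the class:

```python
# What __init__ and sample() do, spelled out on a toy encoder output.
import torch

params = torch.randn(1, 2 * 16, 4, 32, 32)    # [B, 2*z_channels, T, H, W]
mean, logvar = torch.chunk(params, 2, dim=1)
std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
z = mean + std * torch.randn_like(mean)        # reparameterized sample
print(z.shape)                                  # torch.Size([1, 16, 4, 32, 32])
```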
@@ -453,7 +475,7 @@ class Upsample3D(nn.Module):
         else:
             self.Conv2d_0 = conv

-        self.norm = False
+        self.norm = None

     def forward(
         self,
@@ -1255,6 +1277,7 @@ class Decoder3D(nn.Module):
         latent_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
+        sample = sample.to(next(self.parameters()).device)
         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
@@ -1397,10 +1420,10 @@ class VideoAutoencoderKL(nn.Module):
     def _decode(
         self, z: torch.Tensor
     ) -> torch.Tensor:
-        _z = z.to(self.device)
+        latent = z.to(self.device)
         if self.post_quant_conv is not None:
-            _z = self.post_quant_conv(_z)
-        output = self.decoder(_z)
+            latent = self.post_quant_conv(latent)
+        output = self.decoder(latent)
         return output.to(z.device)

     def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -1473,9 +1496,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p

     def decode(self, z: torch.FloatTensor):
+        latent = z.unsqueeze(0)
+        scale = 0.9152
+        shift = 0
+        latent = latent / scale + shift
+        latent = rearrange(latent, "b ... c -> b c ...")
+        latent = latent.squeeze(2)
         if z.ndim == 4:
             z = z.unsqueeze(2)
-        x = super().decode(z).sample.squeeze(2)
+        x = super().decode(latent).squeeze(2)
         return x

     def preprocess(self, x: torch.Tensor):
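Sketch (toy shapes) of the latent handling added to `decode()`: incoming latents are channels-last and normalized with the same scale/shift used at encode time, so they are un-scaled with `z / scale + shift` and moved back to channels-first before the VAE decoder runs:

```python
# Un-normalize and restore channels-first layout before decoding.
import torch
from einops import rearrange

scale, shift = 0.9152, 0
z = torch.randn(4, 32, 32, 16)                    # [T, H, W, C], normalized
latent = z.unsqueeze(0) / scale + shift           # [1, T, H, W, C]
latent = rearrange(latent, "b ... c -> b c ...")  # [1, C, T, H, W]
print(latent.shape)                                # torch.Size([1, 16, 4, 32, 32])
```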

View File

@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):
 def get_conditions(latent, latent_blur):
     t, h, w, c = latent.shape
-    cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
-    #cond[:, ..., :-1] = latent_blur[:]
-    #cond[:, ..., -1:] = 1.0
+    cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
+    cond[:, ..., :-1] = latent_blur[:]
+    cond[:, ..., -1:] = 1.0
     return cond

 def timestep_transform(timesteps, latents_shapes):
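Layout sketch for the updated `get_conditions` (toy shapes): the condition tensor now has c + 1 channels, the first c carrying the blurred latent and the last a constant mask of ones:

```python
# Condition tensor layout: [blurred latent (c channels) | ones mask (1 channel)].
import torch

t, h, w, c = 4, 32, 32, 16
latent_blur = torch.randn(t, h, w, c)
cond = torch.ones(t, h, w, c + 1)
cond[..., :-1] = latent_blur   # blurred latent in the first c channels
cond[..., -1:] = 1.0           # mask channel
print(cond.shape)               # torch.Size([4, 32, 32, 17])
```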
@@ -117,6 +117,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
         device = vae.patcher.load_device
+        offload_device = vae.patcher.offload_device
         vae = vae.first_stage_model
         scale = 0.9152; shift = 0
@@ -144,6 +145,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae = vae.to(device)
         images = images.to(device)
         latent = vae.encode(images)[0]
+        vae = vae.to(offload_device)
         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")
@@ -196,8 +198,9 @@ class SeedVR2Conditioning(io.ComfyNode):
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

-        negative = [[neg_cond, {"condition": condition}]]
-        positive = [[pos_cond, {"condition": condition}]]
+        cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
+        negative = [[cond, {"condition": condition}]]
+        positive = [[cond, {"condition": condition}]]
        return io.NodeOutput(positive, negative, {"samples": noises})
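Sketch (toy sizes) of the padding-then-stacking step above: the shorter prompt embedding is zero-padded along the sequence dimension so both fit in one tensor, which `NaDiT.forward` later splits back apart with `chunk(2, dim=0)`. Only the pad-positive branch is shown:

```python
# Pad pos_cond to neg_cond's length, then stack into a single [1, 2, L, C] tensor.
import torch
import torch.nn.functional as F

pos_cond = torch.randn(58, 5120)
neg_cond = torch.randn(64, 5120)
diff = neg_cond.shape[0] - pos_cond.shape[0]
if diff > 0:
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))   # pad rows (sequence dim)
cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
print(cond.shape)                                  # torch.Size([1, 2, 64, 5120])
```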