commit 768c9cedf8 (parent d629c8f910)
mirror of https://github.com/comfyanonymous/ComfyUI.git
@@ -1,8 +1,10 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import einops
from einops import rearrange, einsum, repeat
from einops import rearrange, repeat
import comfy.model_management
from torch import nn
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F
from math import ceil, pi
import torch
@@ -559,6 +561,8 @@ class MMModule(nn.Module):
torch.FloatTensor,
]:
vid_module = self.vid if not self.shared_weights else self.all
device = comfy.model_management.get_torch_device()
vid = vid.to(device)
vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
if not self.vid_only:
txt_module = self.txt if not self.shared_weights else self.all
@@ -616,58 +620,8 @@ class NaMMAttention(nn.Module):

self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)

def forward(
self,
vid: torch.FloatTensor, # l c
txt: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: torch.LongTensor, # b 1
cache: Cache,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:

vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)

vid_q, vid_k, vid_v = vid_qkv.unbind(1)
txt_q, txt_k, txt_v = txt_qkv.unbind(1)

vid_q, txt_q = self.norm_q(vid_q, txt_q)
vid_k, txt_k = self.norm_k(vid_k, txt_k)

if self.rope:
if self.rope.mm:
vid_q, vid_k, txt_q, txt_k = self.rope(
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
)
else:
vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)

vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
all_len = cache("all_len", lambda: vid_len + txt_len)

b = len(vid_len)
vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]

q = torch.cat([vq, tq], dim=1)
k = torch.cat([vk, tk], dim=1)
v = torch.cat([vv, tv], dim=1)

_, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))

attn = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
attn = attn.flatten(0, 1) # to continue working with the rest of the code

attn = rearrange(attn, "l h d -> l (h d)")
vid_out, txt_out = unconcat(attn)

vid_out, txt_out = self.proj_out(vid_out, txt_out)
return vid_out, txt_out
def forward(self):
pass
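The forward pass above packs each sample's video and text tokens into one sequence, runs a single attention call, and splits the result back into the two streams. A minimal sketch of that pack/attend/unpack pattern, with hypothetical shapes and plain scaled_dot_product_attention standing in for optimized_attention:

import torch
import torch.nn.functional as F

b, vid_len, txt_len, heads, head_dim = 2, 16, 4, 8, 64

# hypothetical per-stream q/k/v tensors, shape (b, n, heads, head_dim)
vq, vk, vv = (torch.randn(b, vid_len, heads, head_dim) for _ in range(3))
tq, tk, tv = (torch.randn(b, txt_len, heads, head_dim) for _ in range(3))

# pack: video tokens first, then text tokens, per batch element
q = torch.cat([vq, tq], dim=1).transpose(1, 2)  # (b, heads, L, d)
k = torch.cat([vk, tk], dim=1).transpose(1, 2)
v = torch.cat([vv, tv], dim=1).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)  # (b, L, heads, d)

# unpack: the first vid_len tokens belong to the video branch
vid_out, txt_out = out[:, :vid_len], out[:, vid_len:]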

def window(
hid: torch.FloatTensor, # (L c)
@@ -783,23 +737,78 @@ class NaSwinAttention(NaMMAttention):
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)

# TODO: continue testing
b = len(vid_len_win)
vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
v_lens = vid_len_win.cpu().tolist()
t_lens_batch = txt_len.cpu().tolist()
win_counts = window_count.cpu().tolist()

q = torch.cat([vq, tq], dim=1)
k = torch.cat([vk, tk], dim=1)
v = torch.cat([vv, tv], dim=1)
out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
out = out.flatten(0, 1)
vq_l = torch.split(vid_q, v_lens)
vk_l = torch.split(vid_k, v_lens)
vv_l = torch.split(vid_v, v_lens)

tv_batch = torch.split(txt_v, t_lens_batch)
tv_l = []
for i, count in enumerate(win_counts):
tv_l.extend([tv_batch[i]] * count)

current_txt_len = txt_q.shape[0]
expected_batch_len = sum(t_lens_batch)

if current_txt_len != expected_batch_len:
t_lens_win = txt_len_win.cpu().tolist()

tq_l = torch.split(txt_q, t_lens_win)
tk_l = torch.split(txt_k, t_lens_win)
else:
tq_batch = torch.split(txt_q, t_lens_batch)
tk_batch = torch.split(txt_k, t_lens_batch)

tq_l = []
tk_l = []
for i, count in enumerate(win_counts):
tq_l.extend([tq_batch[i]] * count)
tk_l.extend([tk_batch[i]] * count)

q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]

q = rnn_utils.pad_sequence(q_list, batch_first=True)
k = rnn_utils.pad_sequence(k_list, batch_first=True)
v = rnn_utils.pad_sequence(v_list, batch_first=True)

q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

B, Heads, Max_L, _ = q.shape
combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]

attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)

padding_mask = idx >= len_tensor
attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))

out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)

out = out.transpose(1, 2)

out_flat_list = []
for i, length in enumerate(combined_lens):
out_flat_list.append(out[i, :length])

out = torch.cat(out_flat_list, dim=0)

# text pooling
vid_out, txt_out = unconcat_win(out)

vid_out = rearrange(vid_out, "l h d -> l (h d)")
txt_out = rearrange(txt_out, "l h d -> l (h d)")
vid_out = window_reverse(vid_out)

device = comfy.model_management.get_torch_device()
vid_out, txt_out = vid_out.to(device), txt_out.to(device)
self.proj_out = self.proj_out.to(device)
vid_out, txt_out = self.proj_out(vid_out, txt_out)

return vid_out, txt_out
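In isolation, the pad-and-mask step above does the following: pad the per-window (video + text) sequences to a common length with pad_sequence, then bias the padded key positions to -inf so attention ignores them. A small standalone sketch with hypothetical window lengths and plain scaled_dot_product_attention in place of optimized_attention:

import torch
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

heads, head_dim = 4, 32
lens = [7, 12, 5]  # combined vid+txt token count per window, hypothetical

seqs = [torch.randn(l, heads, head_dim) for l in lens]
x = rnn_utils.pad_sequence(seqs, batch_first=True)  # (B, Max_L, heads, head_dim)
x = x.transpose(1, 2)                               # (B, heads, Max_L, head_dim)

B, _, max_l, _ = x.shape
idx = torch.arange(max_l).unsqueeze(0).expand(B, max_l)
pad = idx >= torch.tensor(lens).unsqueeze(1)        # True at padded key positions

# additive bias: 0 for real tokens, -inf where the key is padding
bias = torch.zeros(B, 1, 1, max_l)
bias.masked_fill_(pad.unsqueeze(1).unsqueeze(1), float("-inf"))

out = F.scaled_dot_product_attention(x, x, x, attn_mask=bias)
out = out.transpose(1, 2)                           # back to (B, Max_L, heads, head_dim)

# strip the padding again before flattening back to a ragged layout
out = torch.cat([out[i, :l] for i, l in enumerate(lens)], dim=0)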
@@ -837,6 +846,8 @@ class SwiGLUMLP(nn.Module):
self.proj_in = nn.Linear(dim, hidden_dim, bias=False)

def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
x = x.to(next(self.proj_in.parameters()).device)
self.proj_out = self.proj_out.to(x.device)
x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
return x

@@ -928,6 +939,7 @@ class NaMMSRTransformerBlock(nn.Module):
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
txt = txt.to(txt_attn.device)
vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)

vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
@@ -967,12 +979,11 @@ class NaPatchOut(PatchOut):
vid: torch.FloatTensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), # for test
vid_shape_before_patchify = None
) -> Tuple[
torch.FloatTensor,
torch.LongTensor,
]:
cache = cache.namespace("patch")
vid_shape_before_patchify = cache.get("vid_shape_before_patchify")

t, h, w = self.patch_size
vid = self.proj(vid)
@@ -1074,6 +1085,16 @@ class AdaSingle(nn.Module):
emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
emb = expand_dims(emb, 1, hid.ndim + 1)

if hid_len is not None:
slice_inputs = lambda x, dim: x
emb = cache(
f"emb_repeat_{idx}_{branch_tag}",
lambda: slice_inputs(
torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
dim=0,
),
)

shiftA, scaleA, gateA = emb.unbind(-1)
shiftB, scaleB, gateB = (
getattr(self, f"{layer}_shift", None),
@@ -1214,8 +1235,8 @@ class NaDiT(nn.Module):
elif len(block_type) != num_layers:
raise ValueError("The length of ``block_type`` should equal ``num_layers``.")
super().__init__()
self.register_parameter("positive_conditioning", torch.empty((58, 5120)))
self.register_parameter("negative_conditioning", torch.empty((64, 5120)))
self.register_buffer("positive_conditioning", torch.empty((58, 5120)))
self.register_buffer("negative_conditioning", torch.empty((64, 5120)))
self.vid_in = NaPatchIn(
in_channels=vid_in_channels,
patch_size=patch_size,
@@ -1306,13 +1327,14 @@ class NaDiT(nn.Module):
x,
timestep,
context, # l c
disable_cache: bool = True, # for test # TODO ?
disable_cache: bool = False, # for test # TODO: gives an error when set to True
**kwargs
):
transformer_options = kwargs.get("transformer_options", {})
conditions = kwargs.get("condition")

pos_cond, neg_cond = context.chunk(2, dim=0)
pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
pos_cond, txt_shape = flatten([pos_cond])
neg_cond, _ = flatten([neg_cond])
txt = torch.cat([pos_cond, neg_cond], dim=0)
@@ -1331,6 +1353,7 @@ class NaDiT(nn.Module):
vid = vid.to(device).to(dtype)
txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))

vid_shape_before_patchify = vid_shape
vid, vid_shape = self.vid_in(vid, vid_shape)

emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
@@ -1358,6 +1381,6 @@ class NaDiT(nn.Module):
branch_tag="vid",
)

vid, vid_shape = self.vid_out(vid, vid_shape, cache)
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify=vid_shape_before_patchify)
vid = unflatten(vid, vid_shape)
return vid
return vid[0]

@@ -6,9 +6,31 @@ import torch.nn.functional as F
from einops import rearrange

from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
from comfy.ldm.modules.attention import optimized_attention

class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)

def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
sample = torch.randn(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
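For context, this is the standard VAE reparameterization: the encoder output is split into mean and log-variance along dim=1 and a latent is drawn as mean + std * noise. A small usage sketch of the class above, with a hypothetical encoder output shape:

import torch

# pretend encoder output: 2 * latent_channels along dim=1, hypothetical shape
moments = torch.randn(1, 2 * 16, 8, 32, 32)  # (b, 2c, t, h, w)

dist = DiagonalGaussianDistribution(moments)
z = dist.sample()  # stochastic latent, shape (1, 16, 8, 32, 32)

# with deterministic=True the std is zeroed, so sample() just returns the mean
z_det = DiagonalGaussianDistribution(moments, deterministic=True).sample()
assert torch.allclose(z_det, dist.mean)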

class SpatialNorm(nn.Module):
def __init__(
self,
@@ -453,7 +475,7 @@ class Upsample3D(nn.Module):
else:
self.Conv2d_0 = conv

self.norm = False
self.norm = None

def forward(
self,
@@ -1255,6 +1277,7 @@ class Decoder3D(nn.Module):
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:

sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample)

upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
@@ -1397,10 +1420,10 @@ class VideoAutoencoderKL(nn.Module):
def _decode(
self, z: torch.Tensor
) -> torch.Tensor:
_z = z.to(self.device)
latent = z.to(self.device)
if self.post_quant_conv is not None:
_z = self.post_quant_conv(_z)
output = self.decoder(_z)
latent = self.post_quant_conv(latent)
output = self.decoder(latent)
return output.to(z.device)

def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -1473,9 +1496,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
return z, p

def decode(self, z: torch.FloatTensor):
latent = z.unsqueeze(0)
scale = 0.9152
shift = 0
latent = latent / scale + shift
latent = rearrange(latent, "b ... c -> b c ...")
latent = latent.squeeze(2)
if z.ndim == 4:
z = z.unsqueeze(2)
x = super().decode(z).sample.squeeze(2)
x = super().decode(latent).squeeze(2)
return x
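The new decode path undoes the latent normalization before handing the tensor to the wrapped autoencoder; the encode side presumably applies the inverse with the same scale = 0.9152, shift = 0 constants seen in SeedVR2InputProcessing below. A toy round trip of just that normalization step:

import torch

scale, shift = 0.9152, 0.0

x = torch.randn(4, 16)             # some latent, hypothetical shape
packed = (x - shift) * scale       # presumed encode-side normalization
restored = packed / scale + shift  # the decode-side inverse used above
assert torch.allclose(restored, x, atol=1e-6)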

def preprocess(self, x: torch.Tensor):

@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):

def get_conditions(latent, latent_blur):
t, h, w, c = latent.shape
cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
#cond[:, ..., :-1] = latent_blur[:]
#cond[:, ..., -1:] = 1.0
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
cond[:, ..., :-1] = latent_blur[:]
cond[:, ..., -1:] = 1.0
return cond

def timestep_transform(timesteps, latents_shapes):
@@ -117,6 +117,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
@classmethod
def execute(cls, images, vae, resolution_height, resolution_width):
device = vae.patcher.load_device
offload_device = vae.patcher.offload_device
vae = vae.first_stage_model
scale = 0.9152; shift = 0

@@ -144,6 +145,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
vae = vae.to(device)
images = images.to(device)
latent = vae.encode(images)[0]
vae = vae.to(offload_device)
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
latent = rearrange(latent, "b c ... -> b ... c")

@@ -196,8 +198,9 @@ class SeedVR2Conditioning(io.ComfyNode):
else:
pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

negative = [[neg_cond, {"condition": condition}]]
positive = [[pos_cond, {"condition": condition}]]
cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
negative = [[cond, {"condition": condition}]]
positive = [[cond, {"condition": condition}]]

return io.NodeOutput(positive, negative, {"samples": noises})
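The conditioning node now stacks the (padded) positive and negative embeddings into a single tensor so both reach the model in one context input, and NaDiT.forward above splits them apart again with squeeze and chunk. A round-trip sketch with hypothetical embedding shapes:

import torch

# hypothetical prompt embeddings after padding to a common length
pos_cond = torch.randn(64, 5120)
neg_cond = torch.randn(64, 5120)

# node side: stack both prompts and add a leading batch axis -> (1, 2, 64, 5120)
cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)

# model side (see NaDiT.forward above): drop the batch axis and split again
pos_back, neg_back = cond.squeeze(0).chunk(2, dim=0)
pos_back, neg_back = pos_back.squeeze(0), neg_back.squeeze(0)
assert torch.equal(pos_back, pos_cond) and torch.equal(neg_back, neg_cond)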