commit d629c8f910 (parent 413ee3f687)
Author: Yousef Rafat
Date: 2025-12-12 00:46:23 +02:00

5 changed files with 41 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
 
-def get_timestep_embedding(timesteps, embedding_dim):
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
     From Fairseq.
@@ -24,11 +24,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
     assert len(timesteps.shape) == 1
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - downscale_freq_shift)
     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
     emb = emb.to(device=timesteps.device)
     emb = timesteps.float()[:, None] * emb[None, :]
     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
     if embedding_dim % 2 == 1:  # zero pad
         emb = torch.nn.functional.pad(emb, (0,1,0,0))
     return emb
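
For reference, a quick sanity check of the two new arguments. This is a minimal sketch assuming the patched helper is importable as shown in the next file's diff; flip_sin_to_cos is a pure permutation of the feature dimension, swapping the default [sin | cos] layout to the diffusers-style [cos | sin]:

import torch
from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding

t = torch.tensor([0, 250, 500, 999])

# default layout: [sin | cos], unchanged from the pre-patch behaviour
emb = get_timestep_embedding(t, 320)

# flipped layout: [cos | sin]
emb_flipped = get_timestep_embedding(t, 320, flip_sin_to_cos=True)

half = 320 // 2
assert torch.equal(emb_flipped, torch.cat([emb[:, half:], emb[:, :half]], dim=-1))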

View File

@@ -1,10 +1,10 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, einsum
+from einops import rearrange, einsum, repeat
 from torch import nn
 import torch.nn.functional as F
-from math import ceil, sqrt, pi
+from math import ceil, pi
 import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
@@ -12,6 +12,7 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
+import math
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -354,8 +355,8 @@ class RotaryEmbedding(nn.Module):
         freqs = self.freqs
 
-        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
+        freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2)
 
         if should_cache and offset == 0:
             self.cached_freqs[:seq_len] = freqs.detach()
@@ -460,6 +461,7 @@ def apply_rotary_emb(
     t_middle = t[..., start_index:end_index]
     t_right = t[..., end_index:]
 
+    freqs = freqs.to(t_middle.device)
     t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
 
     out = torch.cat((t_left, t_transformed, t_right), dim=-1)
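
For context, the line made device-safe here is the standard rotary identity t' = t * cos(theta) + rotate_half(t) * sin(theta). Below is a self-contained sketch of the pattern; the pairwise rotate_half is an assumption (the usual lucidrains-style layout, consistent with the (n r), r = 2 repeat above), not the repository's exact helper:

import torch
from einops import rearrange, repeat

def rotate_half(x):
    # pair adjacent channels (x0, x1) and map them to (-x1, x0)
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')

def apply_rotary(t, freqs, scale=1.0):
    # move freqs to the input's device first, the fix this hunk applies
    freqs = freqs.to(t.device)
    return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

seq, dim = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.einsum('i,j->ij', torch.arange(seq).float(), inv_freq)
freqs = repeat(freqs, '... n -> ... (n r)', r=2)   # duplicate each frequency, as above
out = apply_rotary(torch.randn(seq, dim), freqs)   # shape (8, 16)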
@@ -560,6 +562,7 @@ class MMModule(nn.Module):
             vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
         if not self.vid_only:
             txt_module = self.txt if not self.shared_weights else self.all
+            txt = txt.to(device=vid.device, dtype=vid.dtype)
             txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
 
         return vid, txt
@@ -747,6 +750,7 @@ class NaSwinAttention(NaMMAttention):
         txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
         vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
+        txt_len = txt_len.to(window_count.device)
         txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
         all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
         concat_win, unconcat_win = cache_win(
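
The added .to(window_count.device) matters because torch.Tensor.repeat_interleave requires the repeats tensor and the input to sit on the same device. A toy illustration of how per-sample text lengths expand to per-window lengths (the numbers are invented):

import torch

txt_len = torch.tensor([77, 128])       # text tokens per batch sample
window_count = torch.tensor([3, 2])     # attention windows per sample

# repeat_interleave needs both tensors on one device, hence the added .to()
txt_len = txt_len.to(window_count.device)
txt_len_win = txt_len.repeat_interleave(window_count)
print(txt_len_win)   # tensor([ 77,  77,  77, 128, 128])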
@@ -1122,8 +1126,12 @@ class TimeEmbedding(nn.Module):
         emb = emb.to(dtype)
         emb = self.proj_in(emb)
         emb = self.act(emb)
+        device = next(self.proj_hid.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_hid(emb)
         emb = self.act(emb)
+        device = next(self.proj_out.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_out(emb)
         return emb
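
The inserted hops let the embedding follow each projection layer to whatever device it currently lives on, which matters under layer-wise offloading. A hypothetical stand-in showing the pattern (the Linear/SiLU stack is an assumption, not the module's exact definition):

import torch
from torch import nn

class TimeEmbeddingSketch(nn.Module):
    # hypothetical stand-in; the real module's layers may differ
    def __init__(self, dim, hidden):
        super().__init__()
        self.proj_in = nn.Linear(dim, hidden)
        self.proj_hid = nn.Linear(hidden, hidden)
        self.proj_out = nn.Linear(hidden, hidden)
        self.act = nn.SiLU()

    def forward(self, emb):
        emb = self.act(self.proj_in(emb))
        # follow each layer to wherever offloading has placed it
        emb = emb.to(next(self.proj_hid.parameters()).device)
        emb = self.act(self.proj_hid(emb))
        emb = emb.to(next(self.proj_out.parameters()).device)
        return self.proj_out(emb)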
@@ -1206,6 +1214,8 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
         super().__init__()
+        self.register_parameter("positive_conditioning", nn.Parameter(torch.empty((58, 5120))))
+        self.register_parameter("negative_conditioning", nn.Parameter(torch.empty((64, 5120))))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
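
Note that nn.Module.register_parameter accepts only an nn.Parameter (or None); passing a bare tensor raises a TypeError, hence the nn.Parameter wrapper above. A minimal demonstration:

import torch
from torch import nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # a bare tensor here would raise TypeError; nn.Parameter is required
        self.register_parameter("positive_conditioning", nn.Parameter(torch.empty((58, 5120))))

m = M()
print(m.positive_conditioning.shape)               # torch.Size([58, 5120])
print("positive_conditioning" in m.state_dict())   # True, so it loads from checkpoints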
@@ -1303,11 +1313,9 @@ class NaDiT(nn.Module):
         conditions = kwargs.get("condition")
 
         pos_cond, neg_cond = context.chunk(2, dim=0)
-        # txt_shape should be the same for both
-        pos_cond, txt_shape = flatten(pos_cond)
-        neg_cond, _ = flatten(neg_cond)
+        pos_cond, txt_shape = flatten([pos_cond])
+        neg_cond, _ = flatten([neg_cond])
         txt = torch.cat([pos_cond, neg_cond], dim = 0)
-        txt_shape[0] *= 2
 
         vid = x
         vid, vid_shape = flatten(x)
@@ -1321,7 +1329,7 @@ class NaDiT(nn.Module):
         dtype = next(self.parameters()).dtype
         txt = txt.to(device).to(dtype)
         vid = vid.to(device).to(dtype)
-        txt = self.txt_in(txt)
+        txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
         vid, vid_shape = self.vid_in(vid, vid_shape)

View File

@@ -801,6 +801,7 @@ class SeedVR2(BaseModel):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         condition = kwargs.get("condition", None)
-        out["condition"] = comfy.conds.CONDRegular(condition)
+        if condition is not None:
+            out["condition"] = comfy.conds.CONDRegular(condition)
         return out

View File

@@ -1168,7 +1168,7 @@ class SeedVR2(supported_models_base.BASE):
         out = model_base.SeedVR2(self, device=device)
         return out
 
     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel)
+        return None
 
 class ACEStep(supported_models_base.BASE):
     unet_config = {

View File

@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):
 
 def get_conditions(latent, latent_blur):
     t, h, w, c = latent.shape
-    cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
-    cond[:, ..., :-1] = latent_blur[:]
-    cond[:, ..., -1:] = 1.0
+    cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
+    #cond[:, ..., :-1] = latent_blur[:]
+    #cond[:, ..., -1:] = 1.0
     return cond
 
 def timestep_transform(timesteps, latents_shapes):
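
After this change get_conditions returns only the mask channel; the blurred-latent channels are commented out, so the condition shrinks from (t, h, w, c + 1) to (t, h, w, 1) and downstream code must not expect the blurred latent. A shape walk-through with invented dimensions:

import torch

t, h, w, c = 4, 32, 32, 16                 # invented latent dimensions
latent = torch.randn(t, h, w, c)

# old: blurred latent plus a mask channel -> (t, h, w, c + 1)
# new: the all-ones mask channel only     -> (t, h, w, 1)
cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
assert cond.shape == (4, 32, 32, 1)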
@@ -144,6 +144,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae = vae.to(device)
         images = images.to(device)
         latent = vae.encode(images)[0]
+        latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
+        latent = rearrange(latent, "b c ... -> b ... c")
         latent = (latent - shift) * scale
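
The two added lines normalize the encoded latent to a channels-last 5D layout: a 4D image latent gains a singleton frame axis, then channels move to the end. A quick shape walk-through (the 16-channel latent is an assumption):

import torch
from einops import rearrange

latent = torch.randn(1, 16, 64, 64)        # b c h w, a single encoded image
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
print(latent.shape)                        # (1, 16, 1, 64, 64) = b c t h w
latent = rearrange(latent, "b c ... -> b ... c")
print(latent.shape)                        # (1, 1, 64, 64, 16) = b t h w c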
@@ -156,9 +158,8 @@ class SeedVR2Conditioning(io.ComfyNode):
             node_id="SeedVR2Conditioning",
             category="image/video",
             inputs=[
-                io.Conditioning.Input("text_positive_conditioning"),
-                io.Conditioning.Input("text_negative_conditioning"),
-                io.Latent.Input("vae_conditioning")
+                io.Latent.Input("vae_conditioning"),
+                io.Model.Input("model"),
             ],
             outputs=[io.Conditioning.Output(display_name = "positive"),
                      io.Conditioning.Output(display_name = "negative"),
@@ -166,12 +167,13 @@ class SeedVR2Conditioning(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
+    def execute(cls, vae_conditioning, model) -> io.NodeOutput:
         vae_conditioning = vae_conditioning["samples"]
         device = vae_conditioning.device
-        pos_cond = text_positive_conditioning[0][0]
-        neg_cond = text_negative_conditioning[0][0]
+        model = model.model.diffusion_model
+        pos_cond = model.positive_conditioning
+        neg_cond = model.negative_conditioning
 
         noises = torch.randn_like(vae_conditioning).to(device)
         aug_noises = torch.randn_like(vae_conditioning).to(device)
@@ -181,21 +183,21 @@ class SeedVR2Conditioning(io.ComfyNode):
             torch.tensor([1000.0])
             * cond_noise_scale
         ).to(device)
-        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]
+        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]  # avoid batch dim
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
-        condition = get_conditions(noises, cond)
+        condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
 
-        pos_shape = pos_cond.shape[1]
-        neg_shape = neg_cond.shape[1]
+        pos_shape = pos_cond.shape[0]
+        neg_shape = neg_cond.shape[0]
         diff = abs(pos_shape - neg_shape)
         if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
 
-        negative = [[pos_cond, {"condition": condition}]]
-        positive = [[neg_cond, {"condition": condition}]]
+        negative = [[neg_cond, {"condition": condition}]]
+        positive = [[pos_cond, {"condition": condition}]]
         return io.NodeOutput(positive, negative, {"samples": noises})
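
The tail of this hunk fixes a swap where the positive and negative embeddings landed in each other's outputs, and the lengths are now measured on dim 0 because the embeddings pulled from the model no longer carry a batch dimension. A standalone check of the padding math, using the (58, 5120) and (64, 5120) shapes registered in NaDiT:

import torch
import torch.nn.functional as F

pos_cond = torch.randn(58, 5120)   # tokens x channels
neg_cond = torch.randn(64, 5120)

pos_shape = pos_cond.shape[0]
neg_shape = neg_cond.shape[0]
diff = abs(pos_shape - neg_shape)
if pos_shape > neg_shape:
    # (0, 0, 0, diff) pads the end of the token dim (second-to-last)
    neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
else:
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

assert pos_cond.shape == neg_cond.shape == (64, 5120)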