Yousef Rafat 2025-12-12 00:46:23 +02:00
parent 413ee3f687
commit d629c8f910
5 changed files with 41 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
 
-def get_timestep_embedding(timesteps, embedding_dim):
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
     From Fairseq.
@@ -24,11 +24,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
     assert len(timesteps.shape) == 1
 
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - downscale_freq_shift)
     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
     emb = emb.to(device=timesteps.device)
     emb = timesteps.float()[:, None] * emb[None, :]
     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    if embedding_dim % 2 == 1:  # zero pad
         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
     return emb
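
As a quick sanity check of the two new arguments (a standalone sketch, not part of the commit): flip_sin_to_cos swaps the sin and cos halves of the output, and downscale_freq_shift replaces the constant 1 in the frequency denominator.

    import torch
    from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding

    t = torch.arange(4)
    emb = get_timestep_embedding(t, 8)                            # [sin | cos] layout
    flipped = get_timestep_embedding(t, 8, flip_sin_to_cos=True)  # [cos | sin] layout

    assert emb.shape == flipped.shape == (4, 8)
    assert torch.allclose(flipped[:, :4], emb[:, 4:])  # cos half moved to the front
    assert torch.allclose(flipped[:, 4:], emb[:, :4])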

View File

@@ -1,10 +1,10 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, einsum
+from einops import rearrange, einsum, repeat
 from torch import nn
 import torch.nn.functional as F
-from math import ceil, sqrt, pi
+from math import ceil, pi
 import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
@@ -12,6 +12,7 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
+import math
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -354,8 +355,8 @@ class RotaryEmbedding(nn.Module):
         freqs = self.freqs
 
-        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
+        freqs = einops.repeat(freqs, '... n -> ... (n r)', r=2)
 
         if should_cache and offset == 0:
             self.cached_freqs[:seq_len] = freqs.detach()
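
For intuition, the '... n -> ... (n r)' pattern duplicates each frequency so that consecutive channel pairs share one rotation angle. A standalone sketch:

    import torch
    import einops

    freqs = torch.tensor([1.0, 2.0, 3.0])
    doubled = einops.repeat(freqs, 'n -> (n r)', r=2)
    # each frequency appears twice, once per element of the rotated pair
    assert torch.equal(doubled, torch.tensor([1.0, 1.0, 2.0, 2.0, 3.0, 3.0]))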
@@ -460,6 +461,7 @@ def apply_rotary_emb(
     t_middle = t[..., start_index:end_index]
     t_right = t[..., end_index:]
+    freqs = freqs.to(t_middle.device)
 
     t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
     out = torch.cat((t_left, t_transformed, t_right), dim=-1)
@@ -560,6 +562,7 @@ class MMModule(nn.Module):
             vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
         if not self.vid_only:
             txt_module = self.txt if not self.shared_weights else self.all
+            txt = txt.to(device=vid.device, dtype=vid.dtype)
             txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
         return vid, txt
@@ -747,6 +750,7 @@ class NaSwinAttention(NaMMAttention):
         txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
         vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
+        txt_len = txt_len.to(window_count.device)
         txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
         all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
         concat_win, unconcat_win = cache_win(
@@ -1122,8 +1126,12 @@ class TimeEmbedding(nn.Module):
         emb = emb.to(dtype)
         emb = self.proj_in(emb)
         emb = self.act(emb)
+        device = next(self.proj_hid.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_hid(emb)
         emb = self.act(emb)
+        device = next(self.proj_out.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_out(emb)
         return emb
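
The pattern added here, reading next(layer.parameters()).device before each projection, keeps the activation on the same device as weights that model offloading may have placed elsewhere. A minimal sketch with a hypothetical layer:

    import torch
    from torch import nn

    proj_hid = nn.Linear(8, 8)  # imagine this layer was offloaded to another device
    x = torch.randn(2, 8)
    x = x.to(next(proj_hid.parameters()).device)  # follow the weights
    y = proj_hid(x)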
@@ -1206,6 +1214,8 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
         super().__init__()
+        self.register_parameter("positive_conditioning", nn.Parameter(torch.empty((58, 5120))))
+        self.register_parameter("negative_conditioning", nn.Parameter(torch.empty((64, 5120))))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
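
Note that torch.nn.Module.register_parameter accepts only an nn.Parameter (or None); a raw tensor raises a TypeError, which is why the empty tensors above must be wrapped. A quick standalone check:

    import torch
    from torch import nn

    m = nn.Module()
    m.register_parameter("ok", nn.Parameter(torch.empty(2, 3)))  # accepted
    try:
        m.register_parameter("bad", torch.empty(2, 3))  # plain tensor
    except TypeError as e:
        print("rejected:", e)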
@@ -1303,11 +1313,9 @@ class NaDiT(nn.Module):
         conditions = kwargs.get("condition")
         pos_cond, neg_cond = context.chunk(2, dim=0)
         # txt_shape should be the same for both
-        pos_cond, txt_shape = flatten(pos_cond)
-        neg_cond, _ = flatten(neg_cond)
+        pos_cond, txt_shape = flatten([pos_cond])
+        neg_cond, _ = flatten([neg_cond])
         txt = torch.cat([pos_cond, neg_cond], dim = 0)
-        txt_shape[0] *= 2
-        vid = x
         vid, vid_shape = flatten(x)
@@ -1321,7 +1329,7 @@ class NaDiT(nn.Module):
         dtype = next(self.parameters()).dtype
         txt = txt.to(device).to(dtype)
         vid = vid.to(device).to(dtype)
-        txt = self.txt_in(txt)
+        txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
         vid, vid_shape = self.vid_in(vid, vid_shape)

View File

@@ -801,7 +801,8 @@ class SeedVR2(BaseModel):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         condition = kwargs.get("condition", None)
-        out["condition"] = comfy.conds.CONDRegular(condition)
+        if condition is not None:
+            out["condition"] = comfy.conds.CONDRegular(condition)
         return out
 
 class PixArt(BaseModel):

View File

@@ -1168,7 +1168,7 @@ class SeedVR2(supported_models_base.BASE):
         out = model_base.SeedVR2(self, device=device)
         return out
 
     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel)
+        return None
 
 class ACEStep(supported_models_base.BASE):
     unet_config = {

View File

@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):
 
 def get_conditions(latent, latent_blur):
     t, h, w, c = latent.shape
-    cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
-    cond[:, ..., :-1] = latent_blur[:]
-    cond[:, ..., -1:] = 1.0
+    cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
+    #cond[:, ..., :-1] = latent_blur[:]
+    #cond[:, ..., -1:] = 1.0
     return cond
def timestep_transform(timesteps, latents_shapes):
@@ -144,6 +144,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae = vae.to(device)
         images = images.to(device)
         latent = vae.encode(images)[0]
+        latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
+
         latent = rearrange(latent, "b c ... -> b ... c")
         latent = (latent - shift) * scale
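
The new unsqueeze gives a single image a singleton time axis so it takes the same b c t h w path as video before the channels-last rearrange. A standalone sketch assuming a 16-channel latent:

    import torch
    from einops import rearrange

    latent = torch.randn(1, 16, 32, 32)  # b c h w, a single encoded image
    latent = latent.unsqueeze(2) if latent.ndim == 4 else latent  # -> b c t h w
    latent = rearrange(latent, "b c ... -> b ... c")  # -> b t h w c
    assert latent.shape == (1, 1, 32, 32, 16)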
@@ -156,9 +158,8 @@ class SeedVR2Conditioning(io.ComfyNode):
             node_id="SeedVR2Conditioning",
             category="image/video",
             inputs=[
-                io.Conditioning.Input("text_positive_conditioning"),
-                io.Conditioning.Input("text_negative_conditioning"),
-                io.Latent.Input("vae_conditioning")
+                io.Latent.Input("vae_conditioning"),
+                io.Model.Input("model"),
             ],
             outputs=[io.Conditioning.Output(display_name = "positive"),
                      io.Conditioning.Output(display_name = "negative"),
@@ -166,12 +167,13 @@ class SeedVR2Conditioning(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
+    def execute(cls, vae_conditioning, model) -> io.NodeOutput:
         vae_conditioning = vae_conditioning["samples"]
         device = vae_conditioning.device
-        pos_cond = text_positive_conditioning[0][0]
-        neg_cond = text_negative_conditioning[0][0]
+        model = model.model.diffusion_model
+        pos_cond = model.positive_conditioning
+        neg_cond = model.negative_conditioning
 
         noises = torch.randn_like(vae_conditioning).to(device)
         aug_noises = torch.randn_like(vae_conditioning).to(device)
@@ -181,21 +183,21 @@ class SeedVR2Conditioning(io.ComfyNode):
             torch.tensor([1000.0])
             * cond_noise_scale
         ).to(device)
-        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]
+        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]  # avoid batch dim
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
-        condition = get_conditions(noises, cond)
+        condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
 
-        pos_shape = pos_cond.shape[1]
-        neg_shape = neg_cond.shape[1]
+        pos_shape = pos_cond.shape[0]
+        neg_shape = neg_cond.shape[0]
         diff = abs(pos_shape - neg_shape)
         if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
 
-        negative = [[pos_cond, {"condition": condition}]]
-        positive = [[neg_cond, {"condition": condition}]]
+        negative = [[neg_cond, {"condition": condition}]]
+        positive = [[pos_cond, {"condition": condition}]]
         return io.NodeOutput(positive, negative, {"samples": noises})
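
For reference, the F.pad calls above zero-pad the shorter conditioning along its token axis so positive and negative end up with matching shapes; a standalone sketch using the 58- and 64-token shapes registered on the model:

    import torch
    import torch.nn.functional as F

    pos_cond = torch.randn(58, 5120)
    neg_cond = torch.randn(64, 5120)
    diff = abs(pos_cond.shape[0] - neg_cond.shape[0])
    # (0, 0, 0, diff) leaves the last dim alone and appends `diff` zero rows
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
    assert pos_cond.shape == neg_cond.shape == (64, 5120)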