From d629c8f910593b8aee6f4b03ec8deea6baacfdd8 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Fri, 12 Dec 2025 00:46:23 +0200
Subject: [PATCH] seedvr: embed SeedVR2 text conditioning in the model and fix
 device handling

---
 comfy/ldm/modules/diffusionmodules/model.py |  6 ++--
 comfy/ldm/seedvr/model.py                   | 26 +++++++++++------
 comfy/model_base.py                         |  3 +-
 comfy/supported_models.py                   |  2 +-
 comfy_extras/nodes_seedvr.py                | 32 +++++++++++----------
 5 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 8162742cf..aa37b09bb 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -13,7 +13,7 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
 
-def get_timestep_embedding(timesteps, embedding_dim):
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
     From Fairseq.
@@ -24,11 +24,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
     assert len(timesteps.shape) == 1
 
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - downscale_freq_shift)
     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
     emb = emb.to(device=timesteps.device)
     emb = timesteps.float()[:, None] * emb[None, :]
     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if flip_sin_to_cos:  # reorder the halves to [cos | sin] (diffusers convention)
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
     if embedding_dim % 2 == 1:  # zero pad
         emb = torch.nn.functional.pad(emb, (0,1,0,0))
     return emb
diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 98121f26f..7444e2823 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, einsum
+from einops import rearrange, einsum, repeat
 from torch import nn
 import torch.nn.functional as F
-from math import ceil, sqrt, pi
+from math import ceil, pi
 import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
@@ -12,6 +12,7 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
+import math
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -354,8 +355,8 @@ class RotaryEmbedding(nn.Module):
 
         freqs = self.freqs
 
-        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
+        freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2)
 
         if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
@@ -460,6 +461,7 @@ def apply_rotary_emb(
     t_middle = t[..., start_index:end_index]
     t_right = t[..., end_index:]
 
+    freqs = freqs.to(t_middle.device)  # rotary freqs may live on a different device
     t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
 
     out = torch.cat((t_left, t_transformed, t_right), dim=-1)
@@ -560,6 +562,7 @@ class MMModule(nn.Module):
             vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
         if not self.vid_only:
             txt_module = self.txt if not self.shared_weights else self.all
+            txt = txt.to(device=vid.device, dtype=vid.dtype)
             txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
         return vid, txt
 
@@ -747,6 +750,7 @@ class NaSwinAttention(NaMMAttention):
         txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
         vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
 
+        txt_len = txt_len.to(window_count.device)
         txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
         all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
         concat_win, unconcat_win = cache_win(
@@ -1122,8 +1126,12 @@ class TimeEmbedding(nn.Module):
         emb = emb.to(dtype)
         emb = self.proj_in(emb)
         emb = self.act(emb)
+        device = next(self.proj_hid.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_hid(emb)
         emb = self.act(emb)
+        device = next(self.proj_out.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_out(emb)
         return emb
 
@@ -1206,6 +1214,8 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
         super().__init__()
+        self.register_parameter("positive_conditioning", nn.Parameter(torch.empty((58, 5120))))  # register_parameter expects an nn.Parameter, not a plain tensor
+        self.register_parameter("negative_conditioning", nn.Parameter(torch.empty((64, 5120))))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
@@ -1303,11 +1313,9 @@ class NaDiT(nn.Module):
 
         conditions = kwargs.get("condition")
         pos_cond, neg_cond = context.chunk(2, dim=0)
-        # txt_shape should be the same for both
-        pos_cond, txt_shape = flatten(pos_cond)
-        neg_cond, _ = flatten(neg_cond)
+        pos_cond, txt_shape = flatten([pos_cond])
+        neg_cond, _ = flatten([neg_cond])
         txt = torch.cat([pos_cond, neg_cond], dim = 0)
-        txt_shape[0] *= 2
 
         vid = x
         vid, vid_shape = flatten(x)
@@ -1321,7 +1329,7 @@ class NaDiT(nn.Module):
         dtype = next(self.parameters()).dtype
         txt = txt.to(device).to(dtype)
         vid = vid.to(device).to(dtype)
-        txt = self.txt_in(txt)
+        txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
 
         vid, vid_shape = self.vid_in(vid, vid_shape)
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index f9cc26bfb..f685ba161 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -801,7 +801,8 @@ class SeedVR2(BaseModel):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         condition = kwargs.get("condition", None)
-        out["condition"] = comfy.conds.CONDRegular(condition)
+        if condition is not None:
+            out["condition"] = comfy.conds.CONDRegular(condition)
         return out
 
 class PixArt(BaseModel):
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 4162a1f5e..1cab38f97 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1168,7 +1168,7 @@ class SeedVR2(supported_models_base.BASE):
         out = model_base.SeedVR2(self, device=device)
         return out
     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel)
+        return None
 
 class ACEStep(supported_models_base.BASE):
     unet_config = {
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 9e8429b66..8a108f37e 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):
 
 def get_conditions(latent, latent_blur):
     t, h, w, c = latent.shape
-    cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
-    cond[:, ..., :-1] = latent_blur[:]
-    cond[:, ..., -1:] = 1.0
+    cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)  # keep only the mask channel
+    # cond[:, ..., :-1] = latent_blur[:]
+    # cond[:, ..., -1:] = 1.0
     return cond
 
 def timestep_transform(timesteps, latents_shapes):
@@ -144,6 +144,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae = vae.to(device)
         images = images.to(device)
         latent = vae.encode(images)[0]
+        latent = latent.unsqueeze(2) if latent.ndim == 4 else latent  # add a time axis to image latents
+        latent = rearrange(latent, "b c ... -> b ... c")
 
         latent = (latent - shift) * scale
 
@@ -156,9 +158,8 @@ class SeedVR2Conditioning(io.ComfyNode):
             node_id="SeedVR2Conditioning",
             category="image/video",
             inputs=[
-                io.Conditioning.Input("text_positive_conditioning"),
-                io.Conditioning.Input("text_negative_conditioning"),
-                io.Latent.Input("vae_conditioning")
+                io.Latent.Input("vae_conditioning"),
+                io.Model.Input("model"),
             ],
             outputs=[io.Conditioning.Output(display_name = "positive"),
                      io.Conditioning.Output(display_name = "negative"),
@@ -166,12 +167,13 @@
         )
 
     @classmethod
-    def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
+    def execute(cls, vae_conditioning, model) -> io.NodeOutput:
         vae_conditioning = vae_conditioning["samples"]
         device = vae_conditioning.device
 
-        pos_cond = text_positive_conditioning[0][0]
-        neg_cond = text_negative_conditioning[0][0]
+        model = model.model.diffusion_model
+        pos_cond = model.positive_conditioning
+        neg_cond = model.negative_conditioning
 
         noises = torch.randn_like(vae_conditioning).to(device)
         aug_noises = torch.randn_like(vae_conditioning).to(device)
@@ -181,21 +183,21 @@
             torch.tensor([1000.0]) * cond_noise_scale
         ).to(device)
 
-        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]
+        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]  # skip the batch dim
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
-        condition = get_conditions(noises, cond)
+        condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
 
-        pos_shape = pos_cond.shape[1]
-        neg_shape = neg_cond.shape[1]
+        pos_shape = pos_cond.shape[0]
+        neg_shape = neg_cond.shape[0]
         diff = abs(pos_shape - neg_shape)
         if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
 
-        negative = [[pos_cond, {"condition": condition}]]
-        positive = [[neg_cond, {"condition": condition}]]
+        negative = [[neg_cond, {"condition": condition}]]
+        positive = [[pos_cond, {"condition": condition}]]
 
         return io.NodeOutput(positive, negative, {"samples": noises})
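
Note on the extended get_timestep_embedding signature above: both new arguments default
to the old behavior (downscale_freq_shift=1 reproduces the original half_dim - 1), and
flip_sin_to_cos=True only swaps the sine and cosine halves. A minimal sanity check,
assuming this patch is applied:

    import torch
    from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding

    t = torch.arange(4)
    default = get_timestep_embedding(t, 128)                        # [sin | cos], matches the old code
    flipped = get_timestep_embedding(t, 128, flip_sin_to_cos=True)  # [cos | sin], diffusers-style ordering
    assert torch.allclose(flipped[:, :64], default[:, 64:])
    assert torch.allclose(flipped[:, 64:], default[:, :64])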
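
The padding branch in SeedVR2Conditioning.execute can be checked the same way: with the
shapes registered on NaDiT above, pos_cond is (58, 5120) and neg_cond is (64, 5120), so
diff is 6 and the positive embedding is zero-padded along its sequence dimension. A small
sketch of just that step, using stand-in random tensors rather than the real parameters:

    import torch
    import torch.nn.functional as F

    pos_cond = torch.randn(58, 5120)   # stand-in for model.positive_conditioning
    neg_cond = torch.randn(64, 5120)   # stand-in for model.negative_conditioning
    diff = abs(pos_cond.shape[0] - neg_cond.shape[0])  # 6
    # F.pad consumes the pad tuple from the last dim inward, so (0, 0, 0, diff)
    # leaves dim -1 alone and appends `diff` zero rows to dim -2.
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
    assert pos_cond.shape == (64, 5120)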