From e3fa1aa415d2d05d1ccd2c3f15f124365c9c85dc Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" Date: Thu, 22 Jan 2026 23:43:01 +0200 Subject: [PATCH] . --- comfy/ldm/seedvr/model.py | 28 +++++++++------------------- comfy_extras/nodes_seedvr.py | 8 ++++++-- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index f42dcb1e2..e7570699e 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -340,8 +340,6 @@ class RotaryEmbedding(nn.Module): if exists(offsets): assert len(offsets) == len(dims) - # get frequencies for each axis - for ind, dim in enumerate(dims): offset = 0 @@ -571,14 +569,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): max_width = max(max_width, w) max_txt_len = max(max_txt_len, l) - # Compute frequencies for actual max dimensions needed - # Add small buffer to improve cache hits across similar batches - vid_freqs = self.get_axial_freqs( - min(max_temporal + 16, 1024), # Cap at 1024, add small buffer - min(max_height + 4, 128), # Cap at 128, add small buffer - min(max_width + 4, 128) # Cap at 128, add small buffer - ) - txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024)) + with torch.cuda.amp.autocast(enabled=False): + vid_freqs = self.get_axial_freqs( + min(max_temporal + 16, 1024), # Cap at 1024, add small buffer + min(max_height + 4, 128), # Cap at 128, add small buffer + min(max_width + 4, 128) # Cap at 128, add small buffer + ).float() + txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024)) # Now slice as before vid_freq_list, txt_freq_list = [], [] @@ -1372,13 +1369,6 @@ class NaDiT(nn.Module): device=device, dtype=dtype ) - self.stop_cfg_index = -1 - - def set_cfg_stop_index(self, cfg): - self.stop_cfg_index = cfg - - def get_cfg_stop_index(self): - return self.stop_cfg_index def forward( self, @@ -1397,6 +1387,7 @@ class NaDiT(nn.Module): conditions = conditions.view(b, 17, -1, h, w) x = x.movedim(1, -1) conditions = conditions.movedim(1, -1) + cache = Cache(disable=disable_cache) try: neg_cond, pos_cond = context.chunk(2, dim=0) @@ -1420,11 +1411,10 @@ class NaDiT(nn.Module): txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device)) vid_shape_before_patchify = vid_shape - vid, vid_shape = self.vid_in(vid, vid_shape) + vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache) emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - cache = Cache(disable=disable_cache) for i, block in enumerate(self.blocks): if ("block", i) in blocks_replace: def block_wrap(args): diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index a7dc101fc..a312c58a0 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -423,8 +423,12 @@ class SeedVR2Conditioning(io.ComfyNode): pos_cond = model.positive_conditioning neg_cond = model.negative_conditioning - noises = torch.randn_like(vae_conditioning).to(device) - aug_noises = torch.randn_like(vae_conditioning).to(device) + for module in model.modules(): + if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): + module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) + + noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device) + aug_noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device) aug_noises = noises * 0.1 + aug_noises * 0.05 cond_noise_scale = latent_noise_scale t = (