This commit is contained in:
Yousef R. Gamaleldin 2026-01-22 23:43:01 +02:00
parent 371c319cf9
commit e3fa1aa415
2 changed files with 15 additions and 21 deletions

View File

@@ -340,8 +340,6 @@ class RotaryEmbedding(nn.Module):
if exists(offsets):
assert len(offsets) == len(dims)
# get frequencies for each axis
for ind, dim in enumerate(dims):
offset = 0
@@ -571,14 +569,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l)
# Compute frequencies for actual max dimensions needed
# Add small buffer to improve cache hits across similar batches
vid_freqs = self.get_axial_freqs(
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
min(max_height + 4, 128), # Cap at 128, add small buffer
min(max_width + 4, 128) # Cap at 128, add small buffer
)
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
with torch.cuda.amp.autocast(enabled=False):
vid_freqs = self.get_axial_freqs(
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
min(max_height + 4, 128), # Cap at 128, add small buffer
min(max_width + 4, 128) # Cap at 128, add small buffer
).float()
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
# Now slice as before
vid_freq_list, txt_freq_list = [], []
@@ -1372,13 +1369,6 @@ class NaDiT(nn.Module):
device=device, dtype=dtype
)
self.stop_cfg_index = -1
def set_cfg_stop_index(self, cfg):
self.stop_cfg_index = cfg
def get_cfg_stop_index(self):
return self.stop_cfg_index
def forward(
self,
@@ -1397,6 +1387,7 @@ class NaDiT(nn.Module):
conditions = conditions.view(b, 17, -1, h, w)
x = x.movedim(1, -1)
conditions = conditions.movedim(1, -1)
cache = Cache(disable=disable_cache)
try:
neg_cond, pos_cond = context.chunk(2, dim=0)
@@ -1420,11 +1411,10 @@ class NaDiT(nn.Module):
txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
vid_shape_before_patchify = vid_shape
vid, vid_shape = self.vid_in(vid, vid_shape)
vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache)
emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
cache = Cache(disable=disable_cache)
for i, block in enumerate(self.blocks):
if ("block", i) in blocks_replace:
def block_wrap(args):

View File

@@ -423,8 +423,12 @@ class SeedVR2Conditioning(io.ComfyNode):
pos_cond = model.positive_conditioning
neg_cond = model.negative_conditioning
noises = torch.randn_like(vae_conditioning).to(device)
aug_noises = torch.randn_like(vae_conditioning).to(device)
for module in model.modules():
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)
noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device)
aug_noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device)
aug_noises = noises * 0.1 + aug_noises * 0.05
cond_noise_scale = latent_noise_scale
t = (