mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
.
This commit is contained in:
parent
371c319cf9
commit
e3fa1aa415
@ -340,8 +340,6 @@ class RotaryEmbedding(nn.Module):
|
||||
if exists(offsets):
|
||||
assert len(offsets) == len(dims)
|
||||
|
||||
# get frequencies for each axis
|
||||
|
||||
for ind, dim in enumerate(dims):
|
||||
|
||||
offset = 0
|
||||
@ -571,14 +569,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
||||
max_width = max(max_width, w)
|
||||
max_txt_len = max(max_txt_len, l)
|
||||
|
||||
# Compute frequencies for actual max dimensions needed
|
||||
# Add small buffer to improve cache hits across similar batches
|
||||
vid_freqs = self.get_axial_freqs(
|
||||
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
|
||||
min(max_height + 4, 128), # Cap at 128, add small buffer
|
||||
min(max_width + 4, 128) # Cap at 128, add small buffer
|
||||
)
|
||||
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
vid_freqs = self.get_axial_freqs(
|
||||
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
|
||||
min(max_height + 4, 128), # Cap at 128, add small buffer
|
||||
min(max_width + 4, 128) # Cap at 128, add small buffer
|
||||
).float()
|
||||
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
|
||||
|
||||
# Now slice as before
|
||||
vid_freq_list, txt_freq_list = [], []
|
||||
@ -1372,13 +1369,6 @@ class NaDiT(nn.Module):
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
|
||||
self.stop_cfg_index = -1
|
||||
|
||||
def set_cfg_stop_index(self, cfg):
|
||||
self.stop_cfg_index = cfg
|
||||
|
||||
def get_cfg_stop_index(self):
|
||||
return self.stop_cfg_index
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1397,6 +1387,7 @@ class NaDiT(nn.Module):
|
||||
conditions = conditions.view(b, 17, -1, h, w)
|
||||
x = x.movedim(1, -1)
|
||||
conditions = conditions.movedim(1, -1)
|
||||
cache = Cache(disable=disable_cache)
|
||||
|
||||
try:
|
||||
neg_cond, pos_cond = context.chunk(2, dim=0)
|
||||
@ -1420,11 +1411,10 @@ class NaDiT(nn.Module):
|
||||
txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
|
||||
|
||||
vid_shape_before_patchify = vid_shape
|
||||
vid, vid_shape = self.vid_in(vid, vid_shape)
|
||||
vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache)
|
||||
|
||||
emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
|
||||
|
||||
cache = Cache(disable=disable_cache)
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
|
||||
@ -423,8 +423,12 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
pos_cond = model.positive_conditioning
|
||||
neg_cond = model.negative_conditioning
|
||||
|
||||
noises = torch.randn_like(vae_conditioning).to(device)
|
||||
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
|
||||
module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)
|
||||
|
||||
noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device)
|
||||
aug_noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device)
|
||||
aug_noises = noises * 0.1 + aug_noises * 0.05
|
||||
cond_noise_scale = latent_noise_scale
|
||||
t = (
|
||||
|
||||
Loading…
Reference in New Issue
Block a user