mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-13 03:52:30 +08:00
.
This commit is contained in:
parent
371c319cf9
commit
e3fa1aa415
@ -340,8 +340,6 @@ class RotaryEmbedding(nn.Module):
|
|||||||
if exists(offsets):
|
if exists(offsets):
|
||||||
assert len(offsets) == len(dims)
|
assert len(offsets) == len(dims)
|
||||||
|
|
||||||
# get frequencies for each axis
|
|
||||||
|
|
||||||
for ind, dim in enumerate(dims):
|
for ind, dim in enumerate(dims):
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
@ -571,14 +569,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
|||||||
max_width = max(max_width, w)
|
max_width = max(max_width, w)
|
||||||
max_txt_len = max(max_txt_len, l)
|
max_txt_len = max(max_txt_len, l)
|
||||||
|
|
||||||
# Compute frequencies for actual max dimensions needed
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
# Add small buffer to improve cache hits across similar batches
|
vid_freqs = self.get_axial_freqs(
|
||||||
vid_freqs = self.get_axial_freqs(
|
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
|
||||||
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
|
min(max_height + 4, 128), # Cap at 128, add small buffer
|
||||||
min(max_height + 4, 128), # Cap at 128, add small buffer
|
min(max_width + 4, 128) # Cap at 128, add small buffer
|
||||||
min(max_width + 4, 128) # Cap at 128, add small buffer
|
).float()
|
||||||
)
|
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
|
||||||
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
|
|
||||||
|
|
||||||
# Now slice as before
|
# Now slice as before
|
||||||
vid_freq_list, txt_freq_list = [], []
|
vid_freq_list, txt_freq_list = [], []
|
||||||
@ -1372,13 +1369,6 @@ class NaDiT(nn.Module):
|
|||||||
device=device, dtype=dtype
|
device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stop_cfg_index = -1
|
|
||||||
|
|
||||||
def set_cfg_stop_index(self, cfg):
|
|
||||||
self.stop_cfg_index = cfg
|
|
||||||
|
|
||||||
def get_cfg_stop_index(self):
|
|
||||||
return self.stop_cfg_index
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -1397,6 +1387,7 @@ class NaDiT(nn.Module):
|
|||||||
conditions = conditions.view(b, 17, -1, h, w)
|
conditions = conditions.view(b, 17, -1, h, w)
|
||||||
x = x.movedim(1, -1)
|
x = x.movedim(1, -1)
|
||||||
conditions = conditions.movedim(1, -1)
|
conditions = conditions.movedim(1, -1)
|
||||||
|
cache = Cache(disable=disable_cache)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
neg_cond, pos_cond = context.chunk(2, dim=0)
|
neg_cond, pos_cond = context.chunk(2, dim=0)
|
||||||
@ -1420,11 +1411,10 @@ class NaDiT(nn.Module):
|
|||||||
txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
|
txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
|
||||||
|
|
||||||
vid_shape_before_patchify = vid_shape
|
vid_shape_before_patchify = vid_shape
|
||||||
vid, vid_shape = self.vid_in(vid, vid_shape)
|
vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache)
|
||||||
|
|
||||||
emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
|
emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
|
||||||
|
|
||||||
cache = Cache(disable=disable_cache)
|
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
if ("block", i) in blocks_replace:
|
if ("block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
|
|||||||
@ -423,8 +423,12 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
pos_cond = model.positive_conditioning
|
pos_cond = model.positive_conditioning
|
||||||
neg_cond = model.negative_conditioning
|
neg_cond = model.negative_conditioning
|
||||||
|
|
||||||
noises = torch.randn_like(vae_conditioning).to(device)
|
for module in model.modules():
|
||||||
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
|
||||||
|
module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)
|
||||||
|
|
||||||
|
noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device)
|
||||||
|
aug_noises = torch.randn_like(vae_conditioning, dtype=vae_conditioning.dtype).to(device)
|
||||||
aug_noises = noises * 0.1 + aug_noises * 0.05
|
aug_noises = noises * 0.1 + aug_noises * 0.05
|
||||||
cond_noise_scale = latent_noise_scale
|
cond_noise_scale = latent_noise_scale
|
||||||
t = (
|
t = (
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user