outputs/speed/memory match custom node

Yousef Rafat 2025-12-24 22:15:27 +02:00
parent 1afc2ed8e6
commit 7b2e5ef0af
2 changed files with 39 additions and 6 deletions


@@ -491,6 +491,11 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
"mmrope_freqs_3d",
lambda: self.get_freqs(vid_shape, txt_shape),
)
target_device = vid_q.device
if vid_freqs.device != target_device:
vid_freqs = vid_freqs.to(target_device)
if txt_freqs.device != target_device:
txt_freqs = txt_freqs.to(target_device)
vid_q = rearrange(vid_q, "L h d -> h L d")
vid_k = rearrange(vid_k, "L h d -> h L d")
vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
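
Note on the device guard added above: the memoized frequencies can be left on a different device than the activations (for example after model offloading), so they are moved lazily and only on mismatch. The same pattern in isolation, with a hypothetical match_device helper:

import torch

def match_device(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Tensor.to() already returns self when the device matches, so the
    # explicit check mainly documents intent and skips the dispatch.
    return t if t.device == ref.device else t.to(ref.device)

vid_freqs and txt_freqs above are handled exactly this way against vid_q.device.
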
@@ -506,6 +511,7 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
txt_k = rearrange(txt_k, "h L d -> L h d")
return vid_q, vid_k, txt_q, txt_k
@torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks
def get_freqs(
self,
vid_shape: torch.LongTensor,
@@ -514,8 +520,29 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
torch.Tensor,
torch.Tensor,
]:
vid_freqs = self.get_axial_freqs(1024, 128, 128)
txt_freqs = self.get_axial_freqs(1024)
# Calculate actual max dimensions needed for this batch
max_temporal = 0
max_height = 0
max_width = 0
max_txt_len = 0
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal
max_height = max(max_height, h)
max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l)
# Compute frequencies for actual max dimensions needed
# Add small buffer to improve cache hits across similar batches
vid_freqs = self.get_axial_freqs(
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
min(max_height + 4, 128), # Cap at 128, add small buffer
min(max_width + 4, 128) # Cap at 128, add small buffer
)
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
# Now slice as before
vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
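
Two mechanisms in this hunk are worth spelling out. First, @torch._dynamo.disable keeps the data-dependent .tolist() loop out of torch.compile's graph, since values read from tensor data force a graph break on every new batch. Second, the +16/+4 headroom and the 1024/128 caps bound the sizes passed to get_axial_freqs, so batches whose padded sizes coincide can reuse a memoized table instead of recomputing the full 1024x128x128 grid. A standalone sketch of the combination, where axial_freqs, padded, and freqs_for_batch are hypothetical stand-ins, not the repo's API:

import functools
import torch

@functools.lru_cache(maxsize=8)
def axial_freqs(t: int, h: int, w: int) -> torch.Tensor:
    # Stand-in for self.get_axial_freqs: expensive and shape-dependent,
    # so it is memoized on the (t, h, w) it was built for.
    return torch.randn(t, h, w, 64)

def padded(size: int, headroom: int, cap: int) -> int:
    # Mirrors min(max_dim + headroom, cap) from the commit above.
    return min(size + headroom, cap)

@torch._dynamo.disable  # .tolist() is data-dependent and would break the graph
def freqs_for_batch(vid_shape: torch.LongTensor) -> torch.Tensor:
    max_t = max_h = max_w = 0
    for f, h, w in vid_shape.tolist():
        max_t, max_h, max_w = max(max_t, f), max(max_h, h), max(max_w, w)
    return axial_freqs(padded(max_t, 16, 1024), padded(max_h, 4, 128), padded(max_w, 4, 128))

(The commit's max_temporal additionally adds the text length l, because each video's temporal slice starts at offset l, as the vid_freqs[l : l + f, :h, :w] slicing shows; the sketch omits that offset.)
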


@@ -65,9 +65,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
if encode:
out = vae_model.encode(t_chunk)
out = vae_model.slicing_encode(t_chunk)
else:
out = vae_model.decode_(t_chunk)
out = vae_model.slicing_decode(t_chunk)
if isinstance(out, (tuple, list)): out = out[0]
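
Routing through slicing_encode / slicing_decode trades one big VAE pass for several smaller ones, which is where the memory headline of this commit comes from. The repo's methods aren't shown in this hunk; as a sketch of what a batch-slicing wrapper of this kind usually does (vae.encode, the method name, and slice_size are assumptions):

import torch

def slicing_encode(vae, x: torch.Tensor, slice_size: int = 1) -> torch.Tensor:
    # Encode slice-by-slice along the batch dim to cap peak activation
    # memory, then stitch the latents back together.
    outs = []
    for i in range(0, x.size(0), slice_size):
        out = vae.encode(x[i : i + slice_size])
        if isinstance(out, (tuple, list)):  # some VAEs return (latent, ...)
            out = out[0]
        outs.append(out)
    return torch.cat(outs, dim=0)
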
@@ -245,6 +245,11 @@ def cut_videos(videos):
assert (videos.size(1) - 1) % (4) == 0
return videos
def side_resize(image, size):
antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)
return resized
class SeedVR2InputProcessing(io.ComfyNode):
@classmethod
def define_schema(cls):
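
side_resize turns antialiasing off whenever the input is a tensor on MPS; antialiased bicubic resampling has historically been missing from the MPS backend, so antialias=True there can raise NotImplementedError or fall back to CPU (version-dependent, worth re-checking against your torch build). An alternative to hardcoding the device name is to probe the kernel once; a sketch with a hypothetical antialias_ok helper:

import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode

def antialias_ok(device: torch.device) -> bool:
    # Try one tiny antialiased bicubic resize; cache the answer per
    # device if this runs on a hot path.
    try:
        TVF.resize(torch.zeros(1, 3, 8, 8, device=device), 4,
                   InterpolationMode.BICUBIC, antialias=True)
        return True
    except (NotImplementedError, RuntimeError):
        return False
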
@@ -285,7 +290,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
max_area = ((resolution_height * resolution_width)** 0.5) ** 2
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
normalize = Normalize(0.5, 0.5)
images = area_resize(images, max_area)
#images = area_resize(images, max_area)
images = side_resize(images, resolution_height)
images = clip(images)
o_h, o_w = images.shape[-2:]
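
The input path switches from area_resize (fit to the total pixel budget max_area) to side_resize(images, resolution_height). TVF.resize with an integer size scales the shorter edge to that value while preserving aspect ratio, so the two strategies produce different shapes for non-square inputs. A quick comparison, with area_resize re-sketched from its name (an assumption, not the repo's exact code):

import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode

img = torch.rand(3, 512, 512)

# side_resize: pin the shorter edge to resolution_height, keep aspect.
side = TVF.resize(img, 1080, InterpolationMode.BICUBIC, antialias=True)
print(side.shape)   # torch.Size([3, 1080, 1080])

# area_resize (sketch): pin the total pixel count to the budget instead.
budget = 1080 * 1920
scale = (budget / (512 * 512)) ** 0.5          # = 2.8125
area = TVF.resize(img, [round(512 * scale), round(512 * scale)],
                  InterpolationMode.BICUBIC, antialias=True)
print(area.shape)   # torch.Size([3, 1440, 1440]) -- same area as 1080x1920

For this square input the old path upsamples to 1440x1440 while the new one stops at 1080x1080, so output size (and memory) now tracks resolution_height directly, matching the custom node's behavior named in the commit title.
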
@@ -348,7 +354,7 @@ class SeedVR2Conditioning(io.ComfyNode):
noises = torch.randn_like(vae_conditioning).to(device)
aug_noises = torch.randn_like(vae_conditioning).to(device)
aug_noises = noises * 0.1 + aug_noises * 0.05
cond_noise_scale = 0.0
t = (
torch.tensor([1000.0])
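
The augmentation noise above is a blend of two independent Gaussian draws, which is again Gaussian with standard deviation sqrt(0.1^2 + 0.05^2) ≈ 0.112. A quick numerical check:

import torch

noises = torch.randn(1_000_000)
aug_noises = torch.randn(1_000_000)
mixed = noises * 0.1 + aug_noises * 0.05
print(mixed.std().item())  # ~0.1118, i.e. sqrt(0.1**2 + 0.05**2)
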