diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index cf3ebd520..7578c0be5 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -491,6 +491,11 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
             "mmrope_freqs_3d",
             lambda: self.get_freqs(vid_shape, txt_shape),
         )
+        target_device = vid_q.device
+        if vid_freqs.device != target_device:
+            vid_freqs = vid_freqs.to(target_device)
+        if txt_freqs.device != target_device:
+            txt_freqs = txt_freqs.to(target_device)
         vid_q = rearrange(vid_q, "L h d -> h L d")
         vid_k = rearrange(vid_k, "L h d -> h L d")
         vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
@@ -506,6 +511,7 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
         txt_k = rearrange(txt_k, "h L d -> L h d")
         return vid_q, vid_k, txt_q, txt_k
 
+    @torch._dynamo.disable  # Disable compilation: .tolist() is data-dependent and causes graph breaks
     def get_freqs(
         self,
         vid_shape: torch.LongTensor,
@@ -514,8 +520,29 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
         torch.Tensor,
         torch.Tensor,
     ]:
-        vid_freqs = self.get_axial_freqs(1024, 128, 128)
-        txt_freqs = self.get_axial_freqs(1024)
+
+        # Calculate actual max dimensions needed for this batch
+        max_temporal = 0
+        max_height = 0
+        max_width = 0
+        max_txt_len = 0
+
+        for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
+            max_temporal = max(max_temporal, l + f)  # Need up to l+f for temporal
+            max_height = max(max_height, h)
+            max_width = max(max_width, w)
+            max_txt_len = max(max_txt_len, l)
+
+        # Compute frequencies for actual max dimensions needed
+        # Add small buffer to improve cache hits across similar batches
+        vid_freqs = self.get_axial_freqs(
+            min(max_temporal + 16, 1024),  # Cap at 1024, add small buffer
+            min(max_height + 4, 128),  # Cap at 128, add small buffer
+            min(max_width + 4, 128)  # Cap at 128, add small buffer
+        )
+        txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
+
+        # Now slice as before
         vid_freq_list, txt_freq_list = [], []
         for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
             vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 8380e4feb..e6ccd44c1 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -65,9 +65,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
             t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
 
             if encode:
-                out = vae_model.encode(t_chunk)
+                out = vae_model.slicing_encode(t_chunk)
             else:
-                out = vae_model.decode_(t_chunk)
+                out = vae_model.slicing_decode(t_chunk)
 
             if isinstance(out, (tuple, list)):
                 out = out[0]
@@ -245,6 +245,11 @@ def cut_videos(videos):
     assert (videos.size(1) - 1) % (4) == 0
     return videos
 
+def side_resize(image, size):
+    antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
+    resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)
+    return resized
+
 class SeedVR2InputProcessing(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -285,7 +290,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         max_area = ((resolution_height * resolution_width) ** 0.5) ** 2
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
-        images = area_resize(images, max_area)
+        # images = area_resize(images, max_area)
+        images = side_resize(images, resolution_height)
         images = clip(images)
 
         o_h, o_w = images.shape[-2:]
@@ -348,7 +354,7 @@ class SeedVR2Conditioning(io.ComfyNode):
         noises = torch.randn_like(vae_conditioning).to(device)
         aug_noises = torch.randn_like(vae_conditioning).to(device)
-
+        aug_noises = noises * 0.1 + aug_noises * 0.05
         cond_noise_scale = 0.0
         t = (
             torch.tensor([1000.0])
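Note on the @torch._dynamo.disable change above: a minimal sketch of why the decorator is needed, assuming the surrounding model runs under torch.compile. .tolist() reads tensor data, so the Python ints it returns cannot be traced symbolically, and each call would force a graph break; disabling dynamo for get_freqs keeps the rest of the compiled graph intact. The function name below is illustrative, not from the patch.

    import torch

    @torch._dynamo.disable  # run eagerly; output depends on tensor data
    def shapes_as_python_ints(shape_tensor: torch.Tensor) -> list:
        # Under torch.compile this .tolist() would trigger a graph break,
        # because the ints it produces are data-dependent.
        return shape_tensor.tolist()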
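Usage sketch for the new side_resize helper (toy shapes, assumed; not from the patch): torchvision's TVF.resize with an int size scales the shorter edge to that size while preserving aspect ratio, which is why SeedVR2InputProcessing can pass resolution_height directly instead of an area budget. Antialiasing is switched off on MPS tensors, matching the patch's workaround for bicubic antialiasing on that backend.

    import torch
    import torchvision.transforms.functional as TVF
    from torchvision.transforms import InterpolationMode

    def side_resize(image, size):
        antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
        return TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)

    frames = torch.rand(1, 3, 360, 640)  # (B, C, H, W) toy input
    out = side_resize(frames, 720)       # shorter edge (360) -> 720
    print(out.shape)                     # torch.Size([1, 3, 720, 1280])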