From 3039c7ba149435b92b81ec3e6f46d99b6d9aca13 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Fri, 26 Dec 2025 23:12:45 +0200
Subject: [PATCH] handle tile edge case by padding the video

---
 comfy_extras/nodes_seedvr.py | 68 ++++++++++++++++++++++++++++++------
 1 file changed, 58 insertions(+), 10 deletions(-)

diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 4ec089dde..314100324 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -59,19 +59,37 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                     input_chunk = temporal_size
                 else:
                     input_chunk = max(1, temporal_size // sf_t)
-
                 for i in range(0, t_dim_size, input_chunk):
                     t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
+                    current_valid_len = t_chunk.shape[2]
+
+                    pad_amount = 0
+                    if current_valid_len < input_chunk:
+                        pad_amount = input_chunk - current_valid_len
+
+                        last_frame = t_chunk[:, :, -1:, :, :]
+                        padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
+
+                        t_chunk = torch.cat([t_chunk, padding], dim=2)
+                        t_chunk = t_chunk.contiguous()
                     if encode:
-                        out = vae_model.encode(t_chunk)
+                        out = vae_model.encode(t_chunk)[0]
                     else:
                         out = vae_model.decode_(t_chunk)
 
                     if isinstance(out, (tuple, list)):
                         out = out[0]
-
                     if out.ndim == 4:
                         out = out.unsqueeze(2)
+                    if pad_amount > 0:
+                        if encode:
+                            expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
+                            out = out[:, :, :expected_valid_out, :, :]
+
+                        else:
+                            expected_valid_out = current_valid_len * sf_t
+                            out = out[:, :, :expected_valid_out, :, :]
+
                     chunk_results.append(out.to(storage_device))
 
         return torch.cat(chunk_results, dim=2)
@@ -149,15 +167,46 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     return result
 
+def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
+    t = videos.size(temporal_dim)
+
+    if count == 0 and not prepend:
+        if t % 4 == 1:
+            return videos
+        count = ((t - 1) // 4 + 1) * 4 + 1 - t
+
+    if count <= 0:
+        return videos
+
+    def select(start, end):
+        return videos[start:end] if temporal_dim == 0 else videos[:, start:end]
+
+    if count >= t:
+        repeat_count = count - t + 1
+        last = select(-1, None)
+
+        if temporal_dim == 0:
+            repeated = last.repeat(repeat_count, 1, 1, 1)
+            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
+        else:
+            repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
+            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]
+
+        return torch.cat([repeated, reversed_frames, videos] if prepend else
+                         [videos, reversed_frames, repeated], dim=temporal_dim)
+
+    if prepend:
+        reversed_frames = select(1, count+1).flip(temporal_dim)
+    else:
+        reversed_frames = select(-count-1, -1).flip(temporal_dim)
+
+    return torch.cat([reversed_frames, videos] if prepend else
+                     [videos, reversed_frames], dim=temporal_dim)
+
 def clear_vae_memory(vae_model):
     for module in vae_model.modules():
         if hasattr(module, "memory"):
             module.memory = None
 
-    if hasattr(vae_model, "original_image_video"):
-        del vae_model.original_image_video
-
-    if hasattr(vae_model, "tiled_args"):
-        del vae_model.tiled_args
 
     gc.collect()
     torch.cuda.empty_cache()
@@ -273,7 +322,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
                 io.Int.Input("spatial_tile_size", default = 512, min = 1),
                 io.Int.Input("spatial_overlap", default = 64, min = 1),
-                io.Int.Input("temporal_tile_size", default = 8, min = 1),
+                io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
                 io.Boolean.Input("enable_tiling", default=False),
             ],
             outputs = [
@@ -318,7 +367,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
         def make_divisible(val, divisor):
             return max(divisor, round(val / divisor) * divisor)
 
-        temporal_tile_size = make_divisible(temporal_tile_size, 4)
         spatial_tile_size = make_divisible(spatial_tile_size, 32)
         spatial_overlap = make_divisible(spatial_overlap, 32)
 
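
Reviewer note (not part of the patch): the pad-then-trim logic in the first
hunk keeps every chunk fed to the VAE at the full input_chunk length, then
drops the part of the output that corresponds to the padding. A minimal
sketch of the arithmetic, assuming a temporal scale factor sf_t = 4; the
shapes and numbers below are illustrative, not taken from the node:

    import torch

    sf_t = 4                                # assumed temporal downscale of the VAE
    input_chunk = 8                         # frames per encode chunk
    t_chunk = torch.randn(1, 3, 5, 16, 16)  # short last chunk: 5 valid frames (B, C, T, H, W)

    current_valid_len = t_chunk.shape[2]
    pad_amount = input_chunk - current_valid_len               # 3
    # repeat the last frame to fill the chunk, as the patch does
    padding = t_chunk[:, :, -1:].repeat(1, 1, pad_amount, 1, 1)
    t_chunk = torch.cat([t_chunk, padding], dim=2)             # now 8 frames

    # encode keeps ceil(5 / 4) = 2 latent frames; decode would instead
    # keep current_valid_len * sf_t pixel frames
    expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
    print(t_chunk.shape[2], expected_valid_out)                # 8 2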
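
Similarly, a small usage sketch for the new pad_video_temporal helper (it
assumes the helper above is in scope; the toy tensor is illustrative). With
the default count=0 it reflect-pads the clip until the frame count satisfies
t % 4 == 1:

    import torch

    # toy clip: 6 frames in (T, H, W, C) layout, frame index written into every pixel
    frames = torch.arange(6, dtype=torch.float32).view(6, 1, 1, 1).expand(6, 2, 2, 3)

    padded = pad_video_temporal(frames, temporal_dim=0)
    print(padded.shape[0])      # 9, the next length with t % 4 == 1
    print(padded[:, 0, 0, 0])   # tensor([0., 1., 2., 3., 4., 5., 4., 3., 2.])

The tail padding is a reflection of the preceding frames (excluding the last
frame itself); only when more padding than existing frames is requested does
the helper fall back to repeating the last frame.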