tile edge case handles by padding vid

This commit is contained in:
Yousef Rafat 2025-12-26 23:12:45 +02:00
parent 9b573da39b
commit 3039c7ba14

View File

@ -59,19 +59,37 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
input_chunk = temporal_size input_chunk = temporal_size
else: else:
input_chunk = max(1, temporal_size // sf_t) input_chunk = max(1, temporal_size // sf_t)
for i in range(0, t_dim_size, input_chunk): for i in range(0, t_dim_size, input_chunk):
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :] t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
current_valid_len = t_chunk.shape[2]
pad_amount = 0
if current_valid_len < input_chunk:
pad_amount = input_chunk - current_valid_len
last_frame = t_chunk[:, :, -1:, :, :]
padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
t_chunk = torch.cat([t_chunk, padding], dim=2)
t_chunk = t_chunk.contiguous()
if encode: if encode:
out = vae_model.encode(t_chunk) out = vae_model.encode(t_chunk)[0]
else: else:
out = vae_model.decode_(t_chunk) out = vae_model.decode_(t_chunk)
if isinstance(out, (tuple, list)): out = out[0] if isinstance(out, (tuple, list)): out = out[0]
if out.ndim == 4: out = out.unsqueeze(2) if out.ndim == 4: out = out.unsqueeze(2)
if pad_amount > 0:
if encode:
expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
out = out[:, :, :expected_valid_out, :, :]
else:
expected_valid_out = current_valid_len * sf_t
out = out[:, :, :expected_valid_out, :, :]
chunk_results.append(out.to(storage_device)) chunk_results.append(out.to(storage_device))
return torch.cat(chunk_results, dim=2) return torch.cat(chunk_results, dim=2)
@ -149,15 +167,46 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
return result return result
def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
t = videos.size(temporal_dim)
if count == 0 and not prepend:
if t % 4 == 1:
return videos
count = ((t - 1) // 4 + 1) * 4 + 1 - t
if count <= 0:
return videos
def select(start, end):
return videos[start:end] if temporal_dim == 0 else videos[:, start:end]
if count >= t:
repeat_count = count - t + 1
last = select(-1, None)
if temporal_dim == 0:
repeated = last.repeat(repeat_count, 1, 1, 1)
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
else:
repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]
return torch.cat([repeated, reversed_frames, videos] if prepend else
[videos, reversed_frames, repeated], dim=temporal_dim)
if prepend:
reversed_frames = select(1, count+1).flip(temporal_dim)
else:
reversed_frames = select(-count-1, -1).flip(temporal_dim)
return torch.cat([reversed_frames, videos] if prepend else
[videos, reversed_frames], dim=temporal_dim)
def clear_vae_memory(vae_model): def clear_vae_memory(vae_model):
for module in vae_model.modules(): for module in vae_model.modules():
if hasattr(module, "memory"): if hasattr(module, "memory"):
module.memory = None module.memory = None
if hasattr(vae_model, "original_image_video"):
del vae_model.original_image_video
if hasattr(vae_model, "tiled_args"):
del vae_model.tiled_args
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -273,7 +322,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
io.Int.Input("spatial_tile_size", default = 512, min = 1), io.Int.Input("spatial_tile_size", default = 512, min = 1),
io.Int.Input("spatial_overlap", default = 64, min = 1), io.Int.Input("spatial_overlap", default = 64, min = 1),
io.Int.Input("temporal_tile_size", default = 8, min = 1), io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
io.Boolean.Input("enable_tiling", default=False), io.Boolean.Input("enable_tiling", default=False),
], ],
outputs = [ outputs = [
@ -318,7 +367,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
def make_divisible(val, divisor): def make_divisible(val, divisor):
return max(divisor, round(val / divisor) * divisor) return max(divisor, round(val / divisor) * divisor)
temporal_tile_size = make_divisible(temporal_tile_size, 4)
spatial_tile_size = make_divisible(spatial_tile_size, 32) spatial_tile_size = make_divisible(spatial_tile_size, 32)
spatial_overlap = make_divisible(spatial_overlap, 32) spatial_overlap = make_divisible(spatial_overlap, 32)