mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
tile edge case handles by padding vid
This commit is contained in:
parent
9b573da39b
commit
3039c7ba14
@ -59,19 +59,37 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
input_chunk = temporal_size
|
||||
else:
|
||||
input_chunk = max(1, temporal_size // sf_t)
|
||||
|
||||
for i in range(0, t_dim_size, input_chunk):
|
||||
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
|
||||
current_valid_len = t_chunk.shape[2]
|
||||
|
||||
pad_amount = 0
|
||||
if current_valid_len < input_chunk:
|
||||
pad_amount = input_chunk - current_valid_len
|
||||
|
||||
last_frame = t_chunk[:, :, -1:, :, :]
|
||||
padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
|
||||
|
||||
t_chunk = torch.cat([t_chunk, padding], dim=2)
|
||||
t_chunk = t_chunk.contiguous()
|
||||
|
||||
if encode:
|
||||
out = vae_model.encode(t_chunk)
|
||||
out = vae_model.encode(t_chunk)[0]
|
||||
else:
|
||||
out = vae_model.decode_(t_chunk)
|
||||
|
||||
if isinstance(out, (tuple, list)): out = out[0]
|
||||
|
||||
if out.ndim == 4: out = out.unsqueeze(2)
|
||||
|
||||
if pad_amount > 0:
|
||||
if encode:
|
||||
expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
|
||||
out = out[:, :, :expected_valid_out, :, :]
|
||||
|
||||
else:
|
||||
expected_valid_out = current_valid_len * sf_t
|
||||
out = out[:, :, :expected_valid_out, :, :]
|
||||
|
||||
chunk_results.append(out.to(storage_device))
|
||||
|
||||
return torch.cat(chunk_results, dim=2)
|
||||
@ -149,15 +167,46 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
|
||||
return result
|
||||
|
||||
def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
|
||||
t = videos.size(temporal_dim)
|
||||
|
||||
if count == 0 and not prepend:
|
||||
if t % 4 == 1:
|
||||
return videos
|
||||
count = ((t - 1) // 4 + 1) * 4 + 1 - t
|
||||
|
||||
if count <= 0:
|
||||
return videos
|
||||
|
||||
def select(start, end):
|
||||
return videos[start:end] if temporal_dim == 0 else videos[:, start:end]
|
||||
|
||||
if count >= t:
|
||||
repeat_count = count - t + 1
|
||||
last = select(-1, None)
|
||||
|
||||
if temporal_dim == 0:
|
||||
repeated = last.repeat(repeat_count, 1, 1, 1)
|
||||
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
|
||||
else:
|
||||
repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
|
||||
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]
|
||||
|
||||
return torch.cat([repeated, reversed_frames, videos] if prepend else
|
||||
[videos, reversed_frames, repeated], dim=temporal_dim)
|
||||
|
||||
if prepend:
|
||||
reversed_frames = select(1, count+1).flip(temporal_dim)
|
||||
else:
|
||||
reversed_frames = select(-count-1, -1).flip(temporal_dim)
|
||||
|
||||
return torch.cat([reversed_frames, videos] if prepend else
|
||||
[videos, reversed_frames], dim=temporal_dim)
|
||||
|
||||
def clear_vae_memory(vae_model):
|
||||
for module in vae_model.modules():
|
||||
if hasattr(module, "memory"):
|
||||
module.memory = None
|
||||
if hasattr(vae_model, "original_image_video"):
|
||||
del vae_model.original_image_video
|
||||
|
||||
if hasattr(vae_model, "tiled_args"):
|
||||
del vae_model.tiled_args
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -273,7 +322,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
|
||||
io.Int.Input("spatial_tile_size", default = 512, min = 1),
|
||||
io.Int.Input("spatial_overlap", default = 64, min = 1),
|
||||
io.Int.Input("temporal_tile_size", default = 8, min = 1),
|
||||
io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
|
||||
io.Boolean.Input("enable_tiling", default=False),
|
||||
],
|
||||
outputs = [
|
||||
@ -318,7 +367,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
def make_divisible(val, divisor):
|
||||
return max(divisor, round(val / divisor) * divisor)
|
||||
|
||||
temporal_tile_size = make_divisible(temporal_tile_size, 4)
|
||||
spatial_tile_size = make_divisible(spatial_tile_size, 32)
|
||||
spatial_overlap = make_divisible(spatial_overlap, 32)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user