mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
tile edge case handles by padding vid
This commit is contained in:
parent
9b573da39b
commit
3039c7ba14
@ -59,19 +59,37 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
|||||||
input_chunk = temporal_size
|
input_chunk = temporal_size
|
||||||
else:
|
else:
|
||||||
input_chunk = max(1, temporal_size // sf_t)
|
input_chunk = max(1, temporal_size // sf_t)
|
||||||
|
|
||||||
for i in range(0, t_dim_size, input_chunk):
|
for i in range(0, t_dim_size, input_chunk):
|
||||||
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
|
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
|
||||||
|
current_valid_len = t_chunk.shape[2]
|
||||||
|
|
||||||
|
pad_amount = 0
|
||||||
|
if current_valid_len < input_chunk:
|
||||||
|
pad_amount = input_chunk - current_valid_len
|
||||||
|
|
||||||
|
last_frame = t_chunk[:, :, -1:, :, :]
|
||||||
|
padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
|
||||||
|
|
||||||
|
t_chunk = torch.cat([t_chunk, padding], dim=2)
|
||||||
|
t_chunk = t_chunk.contiguous()
|
||||||
|
|
||||||
if encode:
|
if encode:
|
||||||
out = vae_model.encode(t_chunk)
|
out = vae_model.encode(t_chunk)[0]
|
||||||
else:
|
else:
|
||||||
out = vae_model.decode_(t_chunk)
|
out = vae_model.decode_(t_chunk)
|
||||||
|
|
||||||
if isinstance(out, (tuple, list)): out = out[0]
|
if isinstance(out, (tuple, list)): out = out[0]
|
||||||
|
|
||||||
if out.ndim == 4: out = out.unsqueeze(2)
|
if out.ndim == 4: out = out.unsqueeze(2)
|
||||||
|
|
||||||
|
if pad_amount > 0:
|
||||||
|
if encode:
|
||||||
|
expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
|
||||||
|
out = out[:, :, :expected_valid_out, :, :]
|
||||||
|
|
||||||
|
else:
|
||||||
|
expected_valid_out = current_valid_len * sf_t
|
||||||
|
out = out[:, :, :expected_valid_out, :, :]
|
||||||
|
|
||||||
chunk_results.append(out.to(storage_device))
|
chunk_results.append(out.to(storage_device))
|
||||||
|
|
||||||
return torch.cat(chunk_results, dim=2)
|
return torch.cat(chunk_results, dim=2)
|
||||||
@ -149,15 +167,46 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
|
||||||
|
t = videos.size(temporal_dim)
|
||||||
|
|
||||||
|
if count == 0 and not prepend:
|
||||||
|
if t % 4 == 1:
|
||||||
|
return videos
|
||||||
|
count = ((t - 1) // 4 + 1) * 4 + 1 - t
|
||||||
|
|
||||||
|
if count <= 0:
|
||||||
|
return videos
|
||||||
|
|
||||||
|
def select(start, end):
|
||||||
|
return videos[start:end] if temporal_dim == 0 else videos[:, start:end]
|
||||||
|
|
||||||
|
if count >= t:
|
||||||
|
repeat_count = count - t + 1
|
||||||
|
last = select(-1, None)
|
||||||
|
|
||||||
|
if temporal_dim == 0:
|
||||||
|
repeated = last.repeat(repeat_count, 1, 1, 1)
|
||||||
|
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
|
||||||
|
else:
|
||||||
|
repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
|
||||||
|
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]
|
||||||
|
|
||||||
|
return torch.cat([repeated, reversed_frames, videos] if prepend else
|
||||||
|
[videos, reversed_frames, repeated], dim=temporal_dim)
|
||||||
|
|
||||||
|
if prepend:
|
||||||
|
reversed_frames = select(1, count+1).flip(temporal_dim)
|
||||||
|
else:
|
||||||
|
reversed_frames = select(-count-1, -1).flip(temporal_dim)
|
||||||
|
|
||||||
|
return torch.cat([reversed_frames, videos] if prepend else
|
||||||
|
[videos, reversed_frames], dim=temporal_dim)
|
||||||
|
|
||||||
def clear_vae_memory(vae_model):
|
def clear_vae_memory(vae_model):
|
||||||
for module in vae_model.modules():
|
for module in vae_model.modules():
|
||||||
if hasattr(module, "memory"):
|
if hasattr(module, "memory"):
|
||||||
module.memory = None
|
module.memory = None
|
||||||
if hasattr(vae_model, "original_image_video"):
|
|
||||||
del vae_model.original_image_video
|
|
||||||
|
|
||||||
if hasattr(vae_model, "tiled_args"):
|
|
||||||
del vae_model.tiled_args
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -273,7 +322,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
|||||||
io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
|
io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
|
||||||
io.Int.Input("spatial_tile_size", default = 512, min = 1),
|
io.Int.Input("spatial_tile_size", default = 512, min = 1),
|
||||||
io.Int.Input("spatial_overlap", default = 64, min = 1),
|
io.Int.Input("spatial_overlap", default = 64, min = 1),
|
||||||
io.Int.Input("temporal_tile_size", default = 8, min = 1),
|
io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
|
||||||
io.Boolean.Input("enable_tiling", default=False),
|
io.Boolean.Input("enable_tiling", default=False),
|
||||||
],
|
],
|
||||||
outputs = [
|
outputs = [
|
||||||
@ -318,7 +367,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
|||||||
def make_divisible(val, divisor):
|
def make_divisible(val, divisor):
|
||||||
return max(divisor, round(val / divisor) * divisor)
|
return max(divisor, round(val / divisor) * divisor)
|
||||||
|
|
||||||
temporal_tile_size = make_divisible(temporal_tile_size, 4)
|
|
||||||
spatial_tile_size = make_divisible(spatial_tile_size, 32)
|
spatial_tile_size = make_divisible(spatial_tile_size, 32)
|
||||||
spatial_overlap = make_divisible(spatial_overlap, 32)
|
spatial_overlap = make_divisible(spatial_overlap, 32)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user