Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-04-24 17:32:40 +08:00)

pr fixes

parent 553f71aa9e
commit afa38ba172
@@ -823,10 +823,6 @@ class NaSwinAttention(NaMMAttention):
         txt_out = rearrange(txt_out, "l h d -> l (h d)")
         vid_out = window_reverse(vid_out)

-        device = comfy.model_management.get_torch_device()
-        dtype = next(self.proj_out.parameters()).dtype
-        vid_out, txt_out = vid_out.to(device=device, dtype=dtype), txt_out.to(device=device, dtype=dtype)
-        self.proj_out = self.proj_out.to(device)
         vid_out, txt_out = self.proj_out(vid_out, txt_out)

         return vid_out, txt_out
@@ -866,10 +862,7 @@ class SwiGLUMLP(nn.Module):
         self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)

     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
-        x = x.to(next(self.proj_in.parameters()).device)
-        self.proj_out = self.proj_out.to(x.device)
-        x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
-        return x
+        return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))

 def get_mlp(mlp_type: Optional[str] = "normal"):
     # 3b and 7b uses different mlp types
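Reviewer note: the simplified forward keeps only the SwiGLU gating, silu(gate(x)) * value(x), with no device shuffling. A minimal, self-contained sketch of that pattern (plain nn.Linear layers stand in for the operations.Linear wrappers; class name and dimensions are illustrative, not the ComfyUI implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUSketch(nn.Module):
    # Illustrative stand-in for SwiGLUMLP: a gate projection and a value projection
    # feed a SiLU gate, then the gated activations are projected back to the model dim.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
        self.proj_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))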
@@ -965,7 +958,6 @@ class NaMMSRTransformerBlock(nn.Module):
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
         vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
-        txt = txt.to(txt_attn.device)
         vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)

         vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
@@ -1188,16 +1180,11 @@ class TimeEmbedding(nn.Module):
             embedding_dim=self.sinusoidal_dim,
             flip_sin_to_cos=False,
             downscale_freq_shift=0,
-        )
-        emb = emb.to(dtype)
+        ).to(dtype)
         emb = self.proj_in(emb)
         emb = self.act(emb)
-        device = next(self.proj_hid.parameters()).device
-        emb = emb.to(device)
         emb = self.proj_hid(emb)
         emb = self.act(emb)
-        device = next(self.proj_out.parameters()).device
-        emb = emb.to(device)
         emb = self.proj_out(emb)
         return emb

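Reviewer note: this hunk folds the dtype cast into the sinusoidal embedding call and drops the per-layer device hops, leaving the pipeline sinusoid -> proj_in -> act -> proj_hid -> act -> proj_out. A rough sketch of that flow under stated assumptions (the sinusoid helper, hidden sizes, and activation are illustrative, not the actual ComfyUI code):

import math
import torch
import torch.nn as nn

def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Assumed sin/cos timestep embedding (flip_sin_to_cos=False, freq shift 0).
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class TimeEmbeddingSketch(nn.Module):
    # Illustrative stand-in for TimeEmbedding: embed the timestep, then a small MLP.
    def __init__(self, sinusoidal_dim: int = 256, hidden_dim: int = 1024):
        super().__init__()
        self.sinusoidal_dim = sinusoidal_dim
        self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
        self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.SiLU()

    def forward(self, timestep: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        emb = sinusoidal_embedding(timestep, self.sinusoidal_dim).to(dtype)
        emb = self.act(self.proj_in(emb))
        emb = self.act(self.proj_hid(emb))
        return self.proj_out(emb)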
@@ -1412,11 +1399,7 @@ class NaDiT(nn.Module):
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])

-        device = next(self.parameters()).device
-        dtype = next(self.parameters()).dtype
-        txt = txt.to(device).to(dtype)
-        vid = vid.to(device).to(dtype)
-        txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
+        txt = self.txt_in(txt)

         vid_shape_before_patchify = vid_shape
         vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache)
@@ -1,16 +1,16 @@
 from contextlib import nullcontext
 from typing import Literal, Optional, Tuple
+import gc
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
 from contextlib import contextmanager
+from comfy.utils import ProgressBar

-import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
-from comfy_extras.nodes_seedvr import tiled_vae

 import math
 from enum import Enum
@@ -20,9 +20,168 @@ import logging
 import comfy.ops
 ops = comfy.ops.disable_weight_init

+
+@torch.inference_mode()
+def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True, **kwargs):
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    x = x.to(next(vae_model.parameters()).dtype)
+    if x.ndim != 5:
+        x = x.unsqueeze(2)
+
+    b, c, d, h, w = x.shape
+
+    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
+    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
+
+    if encode:
+        ti_h, ti_w = tile_size
+        ov_h, ov_w = tile_overlap
+        target_d = (d + sf_t - 1) // sf_t
+        target_h = (h + sf_s - 1) // sf_s
+        target_w = (w + sf_s - 1) // sf_s
+    else:
+        ti_h = max(1, tile_size[0] // sf_s)
+        ti_w = max(1, tile_size[1] // sf_s)
+        ov_h = max(0, tile_overlap[0] // sf_s)
+        ov_w = max(0, tile_overlap[1] // sf_s)
+
+        target_d = d * sf_t
+        target_h = h * sf_s
+        target_w = w * sf_s
+
+    stride_h = max(1, ti_h - ov_h)
+    stride_w = max(1, ti_w - ov_w)
+
+    storage_device = vae_model.device
+    result = None
+    count = None
+
+    def run_temporal_chunks(spatial_tile):
+        chunk_results = []
+        t_dim_size = spatial_tile.shape[2]
+
+        if encode:
+            input_chunk = temporal_size
+        else:
+            input_chunk = max(1, temporal_size // sf_t)
+        for i in range(0, t_dim_size, input_chunk):
+            t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
+            current_valid_len = t_chunk.shape[2]
+
+            pad_amount = 0
+            if current_valid_len < input_chunk:
+                pad_amount = input_chunk - current_valid_len
+
+                last_frame = t_chunk[:, :, -1:, :, :]
+                padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
+
+                t_chunk = torch.cat([t_chunk, padding], dim=2)
+                t_chunk = t_chunk.contiguous()
+
+            if encode:
+                out = vae_model.encode(t_chunk)[0]
+            else:
+                out = vae_model.decode_(t_chunk)
+
+            if isinstance(out, (tuple, list)):
+                out = out[0]
+            if out.ndim == 4:
+                out = out.unsqueeze(2)
+
+            if pad_amount > 0:
+                if encode:
+                    expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
+                    out = out[:, :, :expected_valid_out, :, :]
+
+                else:
+                    expected_valid_out = current_valid_len * sf_t
+                    out = out[:, :, :expected_valid_out, :, :]
+
+            chunk_results.append(out.to(storage_device))
+
+        return torch.cat(chunk_results, dim=2)
+
+    ramp_cache = {}
+    def get_ramp(steps):
+        if steps not in ramp_cache:
+            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
+            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
+        return ramp_cache[steps]
+
+    total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
+    bar = ProgressBar(total_tiles)
+
+    for y_idx in range(0, h, stride_h):
+        y_end = min(y_idx + ti_h, h)
+
+        for x_idx in range(0, w, stride_w):
+            x_end = min(x_idx + ti_w, w)
+
+            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
+
+            # Run VAE
+            tile_out = run_temporal_chunks(tile_x)
+
+            if result is None:
+                b_out, c_out = tile_out.shape[0], tile_out.shape[1]
+                result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
+                count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)
+
+            if encode:
+                ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
+                xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
+                cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
+                cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
+            else:
+                ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
+                xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
+                cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
+                cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
+
+            w_h = torch.ones((tile_out.shape[3],), device=storage_device)
+            w_w = torch.ones((tile_out.shape[4],), device=storage_device)
+
+            if cur_ov_h > 0:
+                r = get_ramp(cur_ov_h)
+                if y_idx > 0:
+                    w_h[:cur_ov_h] = r
+                if y_end < h:
+                    w_h[-cur_ov_h:] = 1.0 - r
+
+            if cur_ov_w > 0:
+                r = get_ramp(cur_ov_w)
+                if x_idx > 0:
+                    w_w[:cur_ov_w] = r
+                if x_end < w:
+                    w_w[-cur_ov_w:] = 1.0 - r
+
+            final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
+
+            valid_d = min(tile_out.shape[2], result.shape[2])
+            tile_out = tile_out[:, :, :valid_d, :, :]
+
+            tile_out.mul_(final_weight)
+
+            result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
+            count[:, :, :, ys:ye, xs:xe] += final_weight
+
+            del tile_out, final_weight, tile_x, w_h, w_w
+            bar.update(1)
+
+    result.div_(count.clamp(min=1e-6))
+
+    if result.device != x.device:
+        result = result.to(x.device).to(x.dtype)
+
+    if x.shape[2] == 1 and sf_t == 1:
+        result = result.squeeze(2)
+
+    return result
+
 _NORM_LIMIT = float("inf")


 def get_norm_limit():
     return _NORM_LIMIT

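Reviewer note: the new tiled_vae walks the input with overlapping spatial tiles, runs the VAE on fixed-size temporal chunks, and accumulates tiles weighted by a cached half-cosine ramp before normalizing by the summed weights. A small sketch of the blending math and the expected call shape (illustrative; the VAE object is assumed to expose .device, .encode()/.decode_() and the downsample-factor attributes probed via getattr):

import torch

ov = 4                                      # assumed overlap width, in samples
t = torch.linspace(0, 1, steps=ov)
r = 0.5 - 0.5 * torch.cos(t * torch.pi)     # same ramp as get_ramp(): half-cosine from 0 to 1
print(r)                                    # tensor([0.0000, 0.2500, 0.7500, 1.0000])
print(r + (1.0 - r))                        # all ones: opposing ramps sum to 1 across a seam

# Expected call shape (illustrative, not taken verbatim from a caller):
# latents = tiled_vae(frames_bcthw, vae, tile_size=(512, 512), tile_overlap=(64, 64),
#                     temporal_size=16, encode=True)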
@@ -6,208 +6,12 @@ from einops import rearrange

 import gc
 import comfy.model_management
-from comfy.utils import ProgressBar

 import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
+from comfy.ldm.seedvr.vae import tiled_vae
-@torch.inference_mode()
-def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):
-
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    x = x.to(next(vae_model.parameters()).dtype)
-    if x.ndim != 5:
-        x = x.unsqueeze(2)
-
-    b, c, d, h, w = x.shape
-
-    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
-    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
-
-    if encode:
-        ti_h, ti_w = tile_size
-        ov_h, ov_w = tile_overlap
-        target_d = (d + sf_t - 1) // sf_t
-        target_h = (h + sf_s - 1) // sf_s
-        target_w = (w + sf_s - 1) // sf_s
-    else:
-        ti_h = max(1, tile_size[0] // sf_s)
-        ti_w = max(1, tile_size[1] // sf_s)
-        ov_h = max(0, tile_overlap[0] // sf_s)
-        ov_w = max(0, tile_overlap[1] // sf_s)
-
-        target_d = d * sf_t
-        target_h = h * sf_s
-        target_w = w * sf_s
-
-    stride_h = max(1, ti_h - ov_h)
-    stride_w = max(1, ti_w - ov_w)
-
-    storage_device = vae_model.device
-    result = None
-    count = None
-
-    def run_temporal_chunks(spatial_tile):
-        chunk_results = []
-        t_dim_size = spatial_tile.shape[2]
-
-        if encode:
-            input_chunk = temporal_size
-        else:
-            input_chunk = max(1, temporal_size // sf_t)
-        for i in range(0, t_dim_size, input_chunk):
-            t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
-            current_valid_len = t_chunk.shape[2]
-
-            pad_amount = 0
-            if current_valid_len < input_chunk:
-                pad_amount = input_chunk - current_valid_len
-
-                last_frame = t_chunk[:, :, -1:, :, :]
-                padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
-
-                t_chunk = torch.cat([t_chunk, padding], dim=2)
-                t_chunk = t_chunk.contiguous()
-
-            if encode:
-                out = vae_model.encode(t_chunk)[0]
-            else:
-                out = vae_model.decode_(t_chunk)
-
-            if isinstance(out, (tuple, list)):
-                out = out[0]
-            if out.ndim == 4:
-                out = out.unsqueeze(2)
-
-            if pad_amount > 0:
-                if encode:
-                    expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
-                    out = out[:, :, :expected_valid_out, :, :]
-
-                else:
-                    expected_valid_out = current_valid_len * sf_t
-                    out = out[:, :, :expected_valid_out, :, :]
-
-            chunk_results.append(out.to(storage_device))
-
-        return torch.cat(chunk_results, dim=2)
-
-    ramp_cache = {}
-    def get_ramp(steps):
-        if steps not in ramp_cache:
-            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
-            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
-        return ramp_cache[steps]
-
-    total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
-    bar = ProgressBar(total_tiles)
-
-    for y_idx in range(0, h, stride_h):
-        y_end = min(y_idx + ti_h, h)
-
-        for x_idx in range(0, w, stride_w):
-            x_end = min(x_idx + ti_w, w)
-
-            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
-
-            # Run VAE
-            tile_out = run_temporal_chunks(tile_x)
-
-            if result is None:
-                b_out, c_out = tile_out.shape[0], tile_out.shape[1]
-                result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
-                count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)
-
-            if encode:
-                ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
-                xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
-                cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
-                cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
-            else:
-                ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
-                xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
-                cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
-                cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
-
-            w_h = torch.ones((tile_out.shape[3],), device=storage_device)
-            w_w = torch.ones((tile_out.shape[4],), device=storage_device)
-
-            if cur_ov_h > 0:
-                r = get_ramp(cur_ov_h)
-                if y_idx > 0:
-                    w_h[:cur_ov_h] = r
-                if y_end < h:
-                    w_h[-cur_ov_h:] = 1.0 - r
-
-            if cur_ov_w > 0:
-                r = get_ramp(cur_ov_w)
-                if x_idx > 0:
-                    w_w[:cur_ov_w] = r
-                if x_end < w:
-                    w_w[-cur_ov_w:] = 1.0 - r
-
-            final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
-
-            valid_d = min(tile_out.shape[2], result.shape[2])
-            tile_out = tile_out[:, :, :valid_d, :, :]
-
-            tile_out.mul_(final_weight)
-
-            result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
-            count[:, :, :, ys:ye, xs:xe] += final_weight
-
-            del tile_out, final_weight, tile_x, w_h, w_w
-            bar.update(1)
-
-    result.div_(count.clamp(min=1e-6))
-
-    if result.device != x.device:
-        result = result.to(x.device).to(x.dtype)
-
-    if x.shape[2] == 1 and sf_t == 1:
-        result = result.squeeze(2)
-
-    return result
-
-
-def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
-    t = videos.size(temporal_dim)
-
-    if count == 0 and not prepend:
-        if t % 4 == 1:
-            return videos
-        count = ((t - 1) // 4 + 1) * 4 + 1 - t
-
-    if count <= 0:
-        return videos
-
-    def select(start, end):
-        return videos[start:end] if temporal_dim == 0 else videos[:, start:end]
-
-    if count >= t:
-        repeat_count = count - t + 1
-        last = select(-1, None)
-
-        if temporal_dim == 0:
-            repeated = last.repeat(repeat_count, 1, 1, 1)
-            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
-        else:
-            repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
-            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]
-
-        return torch.cat([repeated, reversed_frames, videos] if prepend else
-                         [videos, reversed_frames, repeated], dim=temporal_dim)
-
-    if prepend:
-        reversed_frames = select(1, count+1).flip(temporal_dim)
-    else:
-        reversed_frames = select(-count-1, -1).flip(temporal_dim)
-
-    return torch.cat([reversed_frames, videos] if prepend else
-                     [videos, reversed_frames], dim=temporal_dim)
-
 def clear_vae_memory(vae_model):
     for module in vae_model.modules():
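Reviewer note: the hunk above drops the local tiled_vae in favor of the copy now imported from comfy.ldm.seedvr.vae, and removes pad_video_temporal, whose default rule padded clips to a 4k+1 frame count by mirroring trailing frames. A worked example of that arithmetic, derived from the removed code (the frame count is illustrative):

# Default padding rule from the removed pad_video_temporal (count == 0, prepend == False):
# pad until the frame count t satisfies t % 4 == 1.
t = 14
count = ((t - 1) // 4 + 1) * 4 + 1 - t   # ((13 // 4) + 1) * 4 + 1 - 14 = 3 extra frames
assert (t + count) % 4 == 1              # 17 frames after padding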