Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

Commit a4e9d071e8 (parent 4fe772fae9): video works
@@ -16,10 +16,6 @@ from torch.nn.modules.utils import _triple
 from torch import nn
 import math
 import logging
-try:
-    from flash_attn import flash_attn_varlen_func
-except:
-    logging.warning("Best results will be achieved with flash attention enabled for SeedVR2")

 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):

@@ -1299,6 +1295,9 @@ class NaDiT(nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")
+        b, tc, h, w = x.shape
+        x = x.view(b, 16, -1, h, w)
+        conditions = conditions.view(b, 17, -1, h, w)
         x = x.movedim(1, -1)
         conditions = conditions.movedim(1, -1)

@@ -1375,11 +1374,11 @@ class NaDiT(nn.Module):
         vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
         vid = unflatten(vid, vid_shape)
         out = torch.stack(vid)
+        out = out.movedim(-1, 1)
+        out = rearrange(out, "b c t h w -> b (c t) h w")
         try:
             pos, neg = out.chunk(2)
             out = torch.cat([neg, pos])
-            out = out.movedim(-1, 1)
             return out
         except:
-            out = out.movedim(-1, 1)
             return out

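Note on the reshapes introduced in the NaDiT hunks above: the sampler hands the model latents packed as "b (c t) h w", and the new view()/movedim() calls unpack them back to 16 latent channels per frame before the transformer blocks (the conditions tensor carries 17 channels per frame; treating the extra channel as a mask is an assumption, not stated by this commit). A minimal standalone sketch of that round trip, with illustrative sizes:

    import torch

    # Packed layout as handed over by the sampler: frames folded into channels.
    b, c, t, h, w = 1, 16, 4, 32, 32
    x_packed = torch.randn(b, c * t, h, w)      # "b (c t) h w"
    x = x_packed.view(b, 16, -1, h, w)          # unpack to b c t h w
    x = x.movedim(1, -1)                        # channels last, as the blocks expect
    assert x.shape == (b, t, h, w, 16)
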
@@ -9,6 +9,7 @@ from torch import Tensor
 import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
+from comfy_extras.nodes_seedvr import tiled_vae

 class DiagonalGaussianDistribution(object):
     def __init__(self, parameters: torch.Tensor, deterministic: bool = False):

@@ -1450,7 +1451,7 @@ class VideoAutoencoderKL(nn.Module):

         return posterior

-    def decode(
+    def decode_(
         self, z: torch.Tensor, return_dict: bool = True
     ):
         decoded = self.slicing_decode(z)

@@ -1541,10 +1542,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         x = self.decode(z).sample
         return x, z, p

-    def encode(self, x, orig_dims):
+    def encode(self, x, orig_dims=None):
         # we need to keep a reference to the image/video so we later can do a colour fix later
-        self.original_image_video = x
-        self.img_dims = orig_dims
+        #self.original_image_video = x
+        if orig_dims is not None:
+            self.img_dims = orig_dims
         if x.ndim == 4:
             x = x.unsqueeze(2)
         x = x.to(next(self.parameters()).dtype)

@@ -1554,6 +1556,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p

     def decode(self, z: torch.FloatTensor):
+        b, tc, h, w = z.shape
+        z = z.view(b, 16, -1, h, w)
         z = z.movedim(1, -1)
         latent = z.unsqueeze(0)
         scale = 0.9152

@@ -1567,12 +1571,31 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):

         target_device = comfy.model_management.get_torch_device()
         self.decoder.to(target_device)
-        x = super().decode(latent).squeeze(2)
+        x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)
+        #x = super().decode(latent).squeeze(2)

-        input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
+        input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w")
+        if x.ndim == 4:
+            x = x.unsqueeze(0)
+
+        # in case of padded frames
+        t = input.size(0)
+        x = x[:, :, :t]
+
+        x = rearrange(x, "b c t h w -> (b t) c h w")
+
         x = wavelet_reconstruction(x, input)
+        x = x.unsqueeze(0)
         o_h, o_w = self.img_dims
         x = x[..., :o_h, :o_w]
+        x = rearrange(x, "b t c h w -> b c t h w")
+
+        # ensure even dims for save video
+        h, w = x.shape[-2:]
+        w2 = w - (w % 2)
+        h2 = h - (h % 2)
+        x = x[..., :h2, :w2]
+
         return x

     def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):

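The decode wrapper above does a fair amount of shape bookkeeping after the tiled decode and the wavelet_reconstruction colour fix: it trims frames that only existed as temporal padding, crops back to the original resolution, and then forces even height/width so the video writer accepts the frames. A minimal sketch of that bookkeeping with illustrative sizes (not taken from the repository):

    import torch

    # Pure shape math, no VAE involved: b c t h w, decoded with padding.
    x = torch.randn(1, 3, 12, 721, 1281)
    t_orig, (o_h, o_w) = 10, (720, 1280)

    x = x[:, :, :t_orig]                    # drop padded frames
    x = x[..., :o_h, :o_w]                  # undo spatial padding
    h, w = x.shape[-2:]
    x = x[..., :h - (h % 2), :w - (w % 2)]  # even dims for saving video
    print(x.shape)                          # torch.Size([1, 3, 10, 720, 1280])
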
@@ -4,12 +4,146 @@ import torch
 import math
 from einops import rearrange

+import gc
 import comfy.model_management
+from comfy.utils import ProgressBar
+
 import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode

+@torch.inference_mode()
+def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    if x.ndim != 5:
+        x = x.unsqueeze(2)
+
+    b, c, d, h, w = x.shape
+
+    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
+    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
+
+    if encode:
+        ti_h, ti_w = tile_size
+        ov_h, ov_w = tile_overlap
+        ti_t = temporal_size
+        ov_t = temporal_overlap
+
+        target_d = (d + sf_t - 1) // sf_t
+        target_h = (h + sf_s - 1) // sf_s
+        target_w = (w + sf_s - 1) // sf_s
+    else:
+        ti_h = max(1, tile_size[0] // sf_s)
+        ti_w = max(1, tile_size[1] // sf_s)
+        ov_h = max(0, tile_overlap[0] // sf_s)
+        ov_w = max(0, tile_overlap[1] // sf_s)
+        ti_t = max(1, temporal_size // sf_t)
+        ov_t = max(0, temporal_overlap // sf_t)
+
+        target_d = d * sf_t
+        target_h = h * sf_s
+        target_w = w * sf_s
+
+    stride_t = max(1, ti_t - ov_t)
+    stride_h = max(1, ti_h - ov_h)
+    stride_w = max(1, ti_w - ov_w)
+
+    storage_device = torch.device("cpu")
+    result = None
+    count = None
+
+    ramp_cache = {}
+    def get_ramp(steps):
+        if steps not in ramp_cache:
+            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
+            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
+        return ramp_cache[steps]
+
+    bar = ProgressBar(d // stride_t)
+    for t_idx in range(0, d, stride_t):
+        t_end = min(t_idx + ti_t, d)
+
+        for y_idx in range(0, h, stride_h):
+            y_end = min(y_idx + ti_h, h)
+
+            for x_idx in range(0, w, stride_w):
+                x_end = min(x_idx + ti_w, w)
+
+                tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end]
+
+                if encode:
+                    tile_out = vae_model.encode(tile_x)[0]
+                else:
+                    tile_out = vae_model.decode_(tile_x)
+
+                if tile_out.ndim == 4:
+                    tile_out = tile_out.unsqueeze(2)
+
+                tile_out = tile_out.to(storage_device).float()
+
+                if result is None:
+                    b_out, c_out = tile_out.shape[0], tile_out.shape[1]
+                    result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
+                    count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
+
+                if encode:
+                    ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
+                    ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
+                    xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
+
+                    cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
+                    cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
+                    cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
+                else:
+                    ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2]
+                    ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
+                    xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
+
+                    cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2))
+                    cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
+                    cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
+
+                w_t = torch.ones((tile_out.shape[2],), device=storage_device)
+                w_h = torch.ones((tile_out.shape[3],), device=storage_device)
+                w_w = torch.ones((tile_out.shape[4],), device=storage_device)
+
+                if cur_ov_t > 0:
+                    r = get_ramp(cur_ov_t)
+                    if t_idx > 0: w_t[:cur_ov_t] = r
+                    if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
+
+                if cur_ov_h > 0:
+                    r = get_ramp(cur_ov_h)
+                    if y_idx > 0: w_h[:cur_ov_h] = r
+                    if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
+
+                if cur_ov_w > 0:
+                    r = get_ramp(cur_ov_w)
+                    if x_idx > 0: w_w[:cur_ov_w] = r
+                    if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
+
+                final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
+
+                tile_out.mul_(final_weight)
+                result[:, :, ts:te, ys:ye, xs:xe] += tile_out
+                count[:, :, ts:te, ys:ye, xs:xe] += final_weight
+
+                del tile_out, final_weight, tile_x, w_t, w_h, w_w
+        bar.update(1)
+    result.div_(count.clamp(min=1e-6))
+
+    if result.device != x.device:
+        result = result.to(x.device).to(x.dtype)
+
+    if x.shape[2] == 1 and sf_t == 1:
+        result = result.squeeze(2)
+
+    return result
+
 def expand_dims(tensor, ndim):
     shape = tensor.shape + (1,) * (ndim - tensor.ndim)
     return tensor.reshape(shape)

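The new tiled_vae helper blends overlapping tiles with a raised-cosine ramp and then normalises by the accumulated per-pixel weights (result / count). A standalone sketch of that ramp, re-implemented here only for illustration (it is not imported from the node file):

    import torch

    def get_ramp(steps: int) -> torch.Tensor:
        t = torch.linspace(0, 1, steps=steps)
        return 0.5 - 0.5 * torch.cos(t * torch.pi)   # smooth 0 -> 1

    overlap = 8
    r = get_ramp(overlap)
    # The outgoing tile is weighted with (1 - r) over the overlap and the
    # incoming tile with r, so the accumulated weight is 1 across the seam.
    print(torch.allclose((1.0 - r) + r, torch.ones(overlap)))   # True

Dividing by the accumulated count at the end keeps the blend normalised even where the ramps are clipped at the borders of the input.
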
@@ -115,7 +249,11 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Image.Input("images"),
                 io.Vae.Input("vae"),
                 io.Int.Input("resolution_height", default = 1280, min = 120), # //
-                io.Int.Input("resolution_width", default = 720, min = 120) # just non-zero value
+                io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
+                io.Int.Input("spatial_tile_size", default = 512, min = -1),
+                io.Int.Input("temporal_tile_size", default = 8, min = -1),
+                io.Int.Input("spatial_overlap", default = 64, min = -1),
+                io.Int.Input("temporal_overlap", default = 8, min = -1),
             ],
             outputs = [
                 io.Latent.Output("vae_conditioning")

@@ -123,7 +261,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         )

     @classmethod
-    def execute(cls, images, vae, resolution_height, resolution_width):
+    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
         device = vae.patcher.load_device

         offload_device = comfy.model_management.intermediate_device()

@@ -155,8 +293,15 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = rearrange(images, "b t c h w -> b c t h w")
         images = images.to(device)
         vae_model = vae_model.to(device)
-        latent = vae_model.encode(images, [o_h, o_w])[0]
+        vae_model.original_image_video = images
+
+        args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
+                "temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap}
+        vae_model.tiled_args = args
+        latent = tiled_vae(images, vae_model, encode=True, **args)
+
         vae_model = vae_model.to(offload_device)
+        vae_model.img_dims = [o_h, o_w]

         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")

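The new node inputs feed the same tile/stride arithmetic that tiled_vae uses internally. A small runnable sketch of the encode-side bookkeeping with purely illustrative sizes (1080p, 33 frames) and the diff's default tile settings and downsample factors:

    # Tile counts and latent target sizes for the encode direction.
    h, w, d = 1080, 1920, 33                  # input height, width, frames
    sf_s, sf_t = 8, 4                         # spatial/temporal downsample factors
    ti = 512                                  # spatial_tile_size
    ov = 64                                   # spatial_overlap
    stride = ti - ov                          # 448

    tiles_h = len(range(0, h, stride))        # 3 tile rows
    tiles_w = len(range(0, w, stride))        # 5 tile columns
    target_h = (h + sf_s - 1) // sf_s         # 135 latent rows
    target_w = (w + sf_s - 1) // sf_s         # 240 latent columns
    target_d = (d + sf_t - 1) // sf_t         # 9 latent frames
    print(tiles_h, tiles_w, target_h, target_w, target_d)
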
@@ -213,6 +358,9 @@ class SeedVR2Conditioning(io.ComfyNode):
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

+        noises = rearrange(noises, "b c t h w -> b (c t) h w")
+        condition = rearrange(condition, "b c t h w -> b (c t) h w")
+
         negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
         positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]