diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index c1b8a1738..716d728c2 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -16,10 +16,6 @@ from torch.nn.modules.utils import _triple
 from torch import nn
 import math
 import logging
-try:
-    from flash_attn import flash_attn_varlen_func
-except:
-    logging.warning("Best results will be achieved with flash attention enabled for SeedVR2")
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -1299,6 +1295,9 @@ class NaDiT(nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")
+        b, tc, h, w = x.shape
+        x = x.view(b, 16, -1, h, w)
+        conditions = conditions.view(b, 17, -1, h, w)
         x = x.movedim(1, -1)
         conditions = conditions.movedim(1, -1)
 
@@ -1375,11 +1374,11 @@ class NaDiT(nn.Module):
         vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
         vid = unflatten(vid, vid_shape)
         out = torch.stack(vid)
+        out = out.movedim(-1, 1)
+        out = rearrange(out, "b c t h w -> b (c t) h w")
         try:
             pos, neg = out.chunk(2)
             out = torch.cat([neg, pos])
-            out = out.movedim(-1, 1)
             return out
         except:
-            out = out.movedim(-1, 1)
             return out
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index d3786e85d..a8f8c31f2 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -9,6 +9,7 @@ from torch import Tensor
 import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
+from comfy_extras.nodes_seedvr import tiled_vae
 
 class DiagonalGaussianDistribution(object):
     def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
@@ -1450,7 +1451,7 @@ class VideoAutoencoderKL(nn.Module):
 
         return posterior
 
-    def decode(
+    def decode_(
         self, z: torch.Tensor, return_dict: bool = True
     ):
         decoded = self.slicing_decode(z)
@@ -1541,10 +1542,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         x = self.decode(z).sample
         return x, z, p
 
-    def encode(self, x, orig_dims):
+    def encode(self, x, orig_dims=None):
         # we need to keep a reference to the image/video so we later can do a colour fix later
-        self.original_image_video = x
-        self.img_dims = orig_dims
+        #self.original_image_video = x
+        if orig_dims is not None:
+            self.img_dims = orig_dims
         if x.ndim == 4:
             x = x.unsqueeze(2)
         x = x.to(next(self.parameters()).dtype)
@@ -1554,6 +1556,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p
 
     def decode(self, z: torch.FloatTensor):
+        b, tc, h, w = z.shape
+        z = z.view(b, 16, -1, h, w)
         z = z.movedim(1, -1)
         latent = z.unsqueeze(0)
         scale = 0.9152
@@ -1567,12 +1571,31 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
             target_device = comfy.model_management.get_torch_device()
             self.decoder.to(target_device)
 
-        x = super().decode(latent).squeeze(2)
+        x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)
+        #x = super().decode(latent).squeeze(2)
+
+        input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w")
+        if x.ndim == 4:
+            x = x.unsqueeze(0)
+
+        # in case of padded frames
+        t = input.size(0)
+        x = x[:, :, :t]
+
+        x = rearrange(x, "b c t h w -> (b t) c h w")
 
-        input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
         x = wavelet_reconstruction(x, input)
+        x = x.unsqueeze(0)
 
         o_h, o_w = self.img_dims
         x = x[..., :o_h, :o_w]
+        x = rearrange(x, "b t c h w -> b c t h w")
+
+        # ensure even dims for save video
+        h, w = x.shape[-2:]
+        w2 = w - (w % 2)
+        h2 = h - (h % 2)
+        x = x[..., :h2, :w2]
+
         return x
 
     def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 2b1d41174..e3281b5f3 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -4,12 +4,146 @@
 import torch
 import math
 from einops import rearrange
+import gc
 import comfy.model_management
+from comfy.utils import ProgressBar
+
 import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
 
+@torch.inference_mode()
+def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    if x.ndim != 5:
+        x = x.unsqueeze(2)
+
+    b, c, d, h, w = x.shape
+
+    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
+    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
+
+    if encode:
+        ti_h, ti_w = tile_size
+        ov_h, ov_w = tile_overlap
+        ti_t = temporal_size
+        ov_t = temporal_overlap
+
+        target_d = (d + sf_t - 1) // sf_t
+        target_h = (h + sf_s - 1) // sf_s
+        target_w = (w + sf_s - 1) // sf_s
+    else:
+        ti_h = max(1, tile_size[0] // sf_s)
+        ti_w = max(1, tile_size[1] // sf_s)
+        ov_h = max(0, tile_overlap[0] // sf_s)
+        ov_w = max(0, tile_overlap[1] // sf_s)
+        ti_t = max(1, temporal_size // sf_t)
+        ov_t = max(0, temporal_overlap // sf_t)
+
+        target_d = d * sf_t
+        target_h = h * sf_s
+        target_w = w * sf_s
+
+    stride_t = max(1, ti_t - ov_t)
+    stride_h = max(1, ti_h - ov_h)
+    stride_w = max(1, ti_w - ov_w)
+
+    storage_device = torch.device("cpu")
+    result = None
+    count = None
+
+    ramp_cache = {}
+    def get_ramp(steps):
+        if steps not in ramp_cache:
+            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
+            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
+        return ramp_cache[steps]
+
+    bar = ProgressBar(d // stride_t)
+    for t_idx in range(0, d, stride_t):
+        t_end = min(t_idx + ti_t, d)
+
+        for y_idx in range(0, h, stride_h):
+            y_end = min(y_idx + ti_h, h)
+
+            for x_idx in range(0, w, stride_w):
+                x_end = min(x_idx + ti_w, w)
+
+                tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end]
+
+                if encode:
+                    tile_out = vae_model.encode(tile_x)[0]
+                else:
+                    tile_out = vae_model.decode_(tile_x)
+
+                if tile_out.ndim == 4:
+                    tile_out = tile_out.unsqueeze(2)
+
+                tile_out = tile_out.to(storage_device).float()
+
+                if result is None:
+                    b_out, c_out = tile_out.shape[0], tile_out.shape[1]
+                    result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
+                    count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
+
+                if encode:
+                    ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
+                    ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
+                    xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
+
+                    cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
+                    cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
+                    cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
+                else:
+                    ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2]
+                    ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
+                    xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
+
+                    cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2))
+                    cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
+                    cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
+
+                w_t = torch.ones((tile_out.shape[2],), device=storage_device)
+                w_h = torch.ones((tile_out.shape[3],), device=storage_device)
+                w_w = torch.ones((tile_out.shape[4],), device=storage_device)
+
+                if cur_ov_t > 0:
+                    r = get_ramp(cur_ov_t)
+                    if t_idx > 0: w_t[:cur_ov_t] = r
+                    if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
+
+                if cur_ov_h > 0:
+                    r = get_ramp(cur_ov_h)
+                    if y_idx > 0: w_h[:cur_ov_h] = r
+                    if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
+
+                if cur_ov_w > 0:
+                    r = get_ramp(cur_ov_w)
+                    if x_idx > 0: w_w[:cur_ov_w] = r
+                    if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
+
+                final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
+
+                tile_out.mul_(final_weight)
+                result[:, :, ts:te, ys:ye, xs:xe] += tile_out
+                count[:, :, ts:te, ys:ye, xs:xe] += final_weight
+
+                del tile_out, final_weight, tile_x, w_t, w_h, w_w
+        bar.update(1)
+
+    result.div_(count.clamp(min=1e-6))
+
+    if result.device != x.device:
+        result = result.to(x.device).to(x.dtype)
+
+    if x.shape[2] == 1 and sf_t == 1:
+        result = result.squeeze(2)
+
+    return result
+
 def expand_dims(tensor, ndim):
     shape = tensor.shape + (1,) * (ndim - tensor.ndim)
     return tensor.reshape(shape)
@@ -115,7 +249,11 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Image.Input("images"),
                 io.Vae.Input("vae"),
                 io.Int.Input("resolution_height", default = 1280, min = 120), # //
-                io.Int.Input("resolution_width", default = 720, min = 120) # just non-zero value
+                io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
+                io.Int.Input("spatial_tile_size", default = 512, min = -1),
+                io.Int.Input("temporal_tile_size", default = 8, min = -1),
+                io.Int.Input("spatial_overlap", default = 64, min = -1),
+                io.Int.Input("temporal_overlap", default = 8, min = -1),
             ],
             outputs = [
                 io.Latent.Output("vae_conditioning")
@@ -123,7 +261,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, images, vae, resolution_height, resolution_width):
+    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
         device = vae.patcher.load_device
         offload_device = comfy.model_management.intermediate_device()
 
@@ -155,8 +293,15 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = rearrange(images, "b t c h w -> b c t h w")
         images = images.to(device)
         vae_model = vae_model.to(device)
-        latent = vae_model.encode(images, [o_h, o_w])[0]
+        vae_model.original_image_video = images
+
+        args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
+                "temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap}
+        vae_model.tiled_args = args
+        latent = tiled_vae(images, vae_model, encode=True, **args)
+
         vae_model = vae_model.to(offload_device)
+        vae_model.img_dims = [o_h, o_w]
 
         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")
@@ -213,6 +358,9 @@ class SeedVR2Conditioning(io.ComfyNode):
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
 
+        noises = rearrange(noises, "b c t h w -> b (c t) h w")
+        condition = rearrange(condition, "b c t h w -> b (c t) h w")
+
         negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
         positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
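
Not part of the patch: the sketch below is a minimal, standalone illustration of the overlap blending that the new tiled_vae helper uses, reduced to one dimension and assuming only plain PyTorch (no ComfyUI imports, hypothetical helper names cosine_ramp and blend_1d). Each tile is weighted with a cosine ramp over the overlap region, the weighted tiles and their weights are accumulated, and the accumulated sum is normalized by the accumulated weight.

# Standalone 1-D sketch of the cosine cross-fade used by tiled_vae (assumptions noted above).
import torch

def cosine_ramp(steps):
    # same ramp shape as tiled_vae's get_ramp: 0 at the seam, 1 toward the tile interior
    t = torch.linspace(0, 1, steps=steps)
    return 0.5 - 0.5 * torch.cos(t * torch.pi)

def blend_1d(signal, tile, overlap):
    # 1-D stand-in for the (t, h, w) tiling loops: slice, weight, accumulate, normalize
    n = signal.shape[0]
    stride = max(1, tile - overlap)
    out = torch.zeros(n)
    weight = torch.zeros(n)
    for start in range(0, n, stride):
        end = min(start + tile, n)
        piece = signal[start:end]           # stand-in for an encoded/decoded tile
        w = torch.ones(end - start)
        ov = min(overlap, (end - start) // 2)
        if ov > 0:
            r = cosine_ramp(ov)
            if start > 0:
                w[:ov] = r                  # fade in against the previous tile
            if end < n:
                w[-ov:] = 1.0 - r           # fade out toward the next tile
        out[start:end] += piece * w
        weight[start:end] += w
    return out / weight.clamp(min=1e-6)

if __name__ == "__main__":
    x = torch.randn(64)
    # overlapping tiles of identical content blend back to the original signal
    print(torch.allclose(blend_1d(x, tile=16, overlap=4), x, atol=1e-6))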