Patch summary: remove the explicit per-forward device/dtype moves in
comfy/ldm/seedvr/model.py, move tiled_vae from comfy_extras/nodes_seedvr.py into
comfy/ldm/seedvr/vae.py (nodes_seedvr.py now imports it from there) and delete
pad_video_temporal from nodes_seedvr.py.
NOTE(review): line breaks, indentation and the position of blank context lines were
reconstructed from a whitespace-collapsed copy of this patch (blank context lines are
placed so that every hunk count balances) - verify against the tree before applying.
NOTE(review): the added tiled_vae carries a new docstring and two review comments, so
the vae.py hunk count was updated and the post-image blob id on its index line is stale.

diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 29b51608e..eab9d83ed 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -823,10 +823,6 @@ class NaSwinAttention(NaMMAttention):
         txt_out = rearrange(txt_out, "l h d -> l (h d)")
         vid_out = window_reverse(vid_out)
 
-        device = comfy.model_management.get_torch_device()
-        dtype = next(self.proj_out.parameters()).dtype
-        vid_out, txt_out = vid_out.to(device=device, dtype=dtype), txt_out.to(device=device, dtype=dtype)
-        self.proj_out = self.proj_out.to(device)
         vid_out, txt_out = self.proj_out(vid_out, txt_out)
 
         return vid_out, txt_out
@@ -866,10 +862,7 @@ class SwiGLUMLP(nn.Module):
         self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
 
     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
-        x = x.to(next(self.proj_in.parameters()).device)
-        self.proj_out = self.proj_out.to(x.device)
-        x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
-        return x
+        return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
 
 def get_mlp(mlp_type: Optional[str] = "normal"):
     # 3b and 7b uses different mlp types
@@ -965,7 +958,6 @@ class NaMMSRTransformerBlock(nn.Module):
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
         vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
-        txt = txt.to(txt_attn.device)
         vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
 
         vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
@@ -1188,16 +1180,11 @@ class TimeEmbedding(nn.Module):
             embedding_dim=self.sinusoidal_dim,
             flip_sin_to_cos=False,
             downscale_freq_shift=0,
-        )
-        emb = emb.to(dtype)
+        ).to(dtype)
         emb = self.proj_in(emb)
         emb = self.act(emb)
-        device = next(self.proj_hid.parameters()).device
-        emb = emb.to(device)
         emb = self.proj_hid(emb)
         emb = self.act(emb)
-        device = next(self.proj_out.parameters()).device
-        emb = emb.to(device)
         emb = self.proj_out(emb)
         return emb
 
@@ -1412,11 +1399,7 @@ class NaDiT(nn.Module):
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
 
-        device = next(self.parameters()).device
-        dtype = next(self.parameters()).dtype
-        txt = txt.to(device).to(dtype)
-        vid = vid.to(device).to(dtype)
-        txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
+        txt = self.txt_in(txt)
 
         vid_shape_before_patchify = vid_shape
         vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache)
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 0f739cce5..9eae4bc52 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -1,16 +1,16 @@
 from contextlib import nullcontext
 from typing import Literal, Optional, Tuple
+import gc
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
 from contextlib import contextmanager
+from comfy.utils import ProgressBar
 
-import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
-from comfy_extras.nodes_seedvr import tiled_vae
 import math
 from enum import Enum
 
@@ -20,9 +20,200 @@ import logging
 import comfy.ops
 ops = comfy.ops.disable_weight_init
 
+
+@torch.inference_mode()
+def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True, **kwargs):
+    """Encode or decode ``x`` in overlapping spatial tiles and temporal chunks.
+
+    Every tile goes through ``vae_model.encode`` (``encode=True``) or
+    ``vae_model.decode_`` in chunks along dim 2, is weighted by a raised-cosine
+    ramp on edges shared with a neighbouring tile, and is accumulated in a
+    float32 buffer on ``vae_model.device`` that is finally divided by the
+    summed weights.
+
+    Args:
+        x: Input tensor; an input that is not 5-D gets a new axis at dim 2.
+        vae_model: Model exposing ``parameters()``, ``device``, ``encode`` and
+            ``decode_``. ``spatial_downsample_factor`` (default 8) and
+            ``temporal_downsample_factor`` (default 4) are read when present.
+        tile_size: Tile (height, width); divided by the spatial factor when
+            decoding.
+        tile_overlap: Overlap (height, width) of neighbouring tiles; divided
+            by the spatial factor when decoding.
+        temporal_size: Frames per chunk; divided by the temporal factor when
+            decoding.
+        encode: ``True`` to encode, ``False`` to decode.
+        **kwargs: Accepted and ignored.
+
+    Returns:
+        The blended tensor, float32 on ``vae_model.device`` unless that device
+        differs from ``x``'s; then it is moved to ``x``'s device and cast to
+        the VAE parameter dtype. Dim 2 is squeezed when ``x`` has one frame
+        and the temporal factor is 1.
+    """
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    x = x.to(next(vae_model.parameters()).dtype)
+    if x.ndim != 5:
+        x = x.unsqueeze(2)
+
+    b, c, d, h, w = x.shape
+
+    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
+    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
+
+    if encode:
+        ti_h, ti_w = tile_size
+        ov_h, ov_w = tile_overlap
+        target_d = (d + sf_t - 1) // sf_t
+        target_h = (h + sf_s - 1) // sf_s
+        target_w = (w + sf_s - 1) // sf_s
+    else:
+        ti_h = max(1, tile_size[0] // sf_s)
+        ti_w = max(1, tile_size[1] // sf_s)
+        ov_h = max(0, tile_overlap[0] // sf_s)
+        ov_w = max(0, tile_overlap[1] // sf_s)
+
+        target_d = d * sf_t
+        target_h = h * sf_s
+        target_w = w * sf_s
+
+    stride_h = max(1, ti_h - ov_h)
+    stride_w = max(1, ti_w - ov_w)
+
+    storage_device = vae_model.device
+    result = None
+    count = None
+
+    def run_temporal_chunks(spatial_tile):
+        chunk_results = []
+        t_dim_size = spatial_tile.shape[2]
+
+        if encode:
+            input_chunk = temporal_size
+        else:
+            input_chunk = max(1, temporal_size // sf_t)
+        for i in range(0, t_dim_size, input_chunk):
+            t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
+            current_valid_len = t_chunk.shape[2]
+
+            pad_amount = 0
+            if current_valid_len < input_chunk:
+                pad_amount = input_chunk - current_valid_len
+
+                last_frame = t_chunk[:, :, -1:, :, :]
+                padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
+
+                t_chunk = torch.cat([t_chunk, padding], dim=2)
+                t_chunk = t_chunk.contiguous()
+
+            if encode:
+                out = vae_model.encode(t_chunk)[0]
+            else:
+                out = vae_model.decode_(t_chunk)
+
+            if isinstance(out, (tuple, list)):
+                out = out[0]
+            if out.ndim == 4:
+                out = out.unsqueeze(2)
+
+            if pad_amount > 0:
+                if encode:
+                    expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
+                    out = out[:, :, :expected_valid_out, :, :]
+
+                else:
+                    expected_valid_out = current_valid_len * sf_t
+                    out = out[:, :, :expected_valid_out, :, :]
+
+            chunk_results.append(out.to(storage_device))
+
+        return torch.cat(chunk_results, dim=2)
+
+    ramp_cache = {}
+    def get_ramp(steps):
+        if steps not in ramp_cache:
+            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
+            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
+        return ramp_cache[steps]
+
+    total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
+    bar = ProgressBar(total_tiles)
+
+    for y_idx in range(0, h, stride_h):
+        y_end = min(y_idx + ti_h, h)
+
+        for x_idx in range(0, w, stride_w):
+            x_end = min(x_idx + ti_w, w)
+
+            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
+
+            # Run VAE
+            tile_out = run_temporal_chunks(tile_x)
+
+            if result is None:
+                b_out, c_out = tile_out.shape[0], tile_out.shape[1]
+                result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
+                count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)
+
+            if encode:
+                ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
+                xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
+                cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
+                cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
+            else:
+                ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
+                xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
+                # NOTE(review): ov_h/ov_w are in latent units here while tile_out is
+                # decoded output, so the ramp presumably spans 1/sf_s of the overlap - confirm.
+                cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
+                cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
+
+            w_h = torch.ones((tile_out.shape[3],), device=storage_device)
+            w_w = torch.ones((tile_out.shape[4],), device=storage_device)
+
+            if cur_ov_h > 0:
+                r = get_ramp(cur_ov_h)
+                if y_idx > 0:
+                    w_h[:cur_ov_h] = r
+                if y_end < h:
+                    w_h[-cur_ov_h:] = 1.0 - r
+
+            if cur_ov_w > 0:
+                r = get_ramp(cur_ov_w)
+                if x_idx > 0:
+                    w_w[:cur_ov_w] = r
+                if x_end < w:
+                    w_w[-cur_ov_w:] = 1.0 - r
+
+            final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
+
+            valid_d = min(tile_out.shape[2], result.shape[2])
+            tile_out = tile_out[:, :, :valid_d, :, :]
+
+            tile_out.mul_(final_weight)
+
+            result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
+            count[:, :, :, ys:ye, xs:xe] += final_weight
+
+            del tile_out, final_weight, tile_x, w_h, w_w
+            bar.update(1)
+
+    result.div_(count.clamp(min=1e-6))
+
+    # NOTE(review): the cast to x.dtype only runs when the devices differ, so
+    # the returned dtype depends on device placement - confirm this is intended.
+    if result.device != x.device:
+        result = result.to(x.device).to(x.dtype)
+
+    if x.shape[2] == 1 and sf_t == 1:
+        result = result.squeeze(2)
+
+    return result
+
 _NORM_LIMIT = float("inf")
-
-
 def get_norm_limit():
     return _NORM_LIMIT
 
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index c39792a8a..8aa166c48 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -6,208 +6,12 @@ from einops import rearrange
 
 import gc
 import comfy.model_management
-from comfy.utils import ProgressBar
 
 import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
-
-@torch.inference_mode()
-def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):
-
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    x = x.to(next(vae_model.parameters()).dtype)
-    if x.ndim != 5:
-        x = x.unsqueeze(2)
-
-    b, c, d, h, w = x.shape
-
-    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
-    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
-
-    if encode:
-        ti_h, ti_w = tile_size
-        ov_h, ov_w = tile_overlap
-        target_d = (d + sf_t - 1) // sf_t
-        target_h = (h + sf_s - 1) // sf_s
-        target_w = (w + sf_s - 1) // sf_s
-    else:
-        ti_h = max(1, tile_size[0] // sf_s)
-        ti_w = max(1, tile_size[1] // sf_s)
-        ov_h = max(0, tile_overlap[0] // sf_s)
-        ov_w = max(0, tile_overlap[1] // sf_s)
-
-        target_d = d * sf_t
-        target_h = h * sf_s
-        target_w = w * sf_s
-
-    stride_h = max(1, ti_h - ov_h)
-    stride_w = max(1, ti_w - ov_w)
-
-    storage_device = vae_model.device
-    result = None
-    count = None
-
-    def run_temporal_chunks(spatial_tile):
-        chunk_results = []
-        t_dim_size = spatial_tile.shape[2]
-
-        if encode:
-            input_chunk = temporal_size
-        else:
-            input_chunk = max(1, temporal_size // sf_t)
-        for i in range(0, t_dim_size, input_chunk):
-            t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
-            current_valid_len = t_chunk.shape[2]
-
-            pad_amount = 0
-            if current_valid_len < input_chunk:
-                pad_amount = input_chunk - current_valid_len
-
-                last_frame = t_chunk[:, :, -1:, :, :]
-                padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
-
-                t_chunk = torch.cat([t_chunk, padding], dim=2)
-                t_chunk = t_chunk.contiguous()
-
-            if encode:
-                out = vae_model.encode(t_chunk)[0]
-            else:
-                out = vae_model.decode_(t_chunk)
-
-            if isinstance(out, (tuple, list)):
-                out = out[0]
-            if out.ndim == 4:
-                out = out.unsqueeze(2)
-
-            if pad_amount > 0:
-                if encode:
-                    expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
-                    out = out[:, :, :expected_valid_out, :, :]
-
-                else:
-                    expected_valid_out = current_valid_len * sf_t
-                    out = out[:, :, :expected_valid_out, :, :]
-
-            chunk_results.append(out.to(storage_device))
-
-        return torch.cat(chunk_results, dim=2)
-
-    ramp_cache = {}
-    def get_ramp(steps):
-        if steps not in ramp_cache:
-            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
-            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
-        return ramp_cache[steps]
-
-    total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
-    bar = ProgressBar(total_tiles)
-
-    for y_idx in range(0, h, stride_h):
-        y_end = min(y_idx + ti_h, h)
-
-        for x_idx in range(0, w, stride_w):
-            x_end = min(x_idx + ti_w, w)
-
-            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
-
-            # Run VAE
-            tile_out = run_temporal_chunks(tile_x)
-
-            if result is None:
-                b_out, c_out = tile_out.shape[0], tile_out.shape[1]
-                result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
-                count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)
-
-            if encode:
-                ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
-                xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
-                cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
-                cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
-            else:
-                ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
-                xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
-                cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
-                cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
-
-            w_h = torch.ones((tile_out.shape[3],), device=storage_device)
-            w_w = torch.ones((tile_out.shape[4],), device=storage_device)
-
-            if cur_ov_h > 0:
-                r = get_ramp(cur_ov_h)
-                if y_idx > 0:
-                    w_h[:cur_ov_h] = r
-                if y_end < h:
-                    w_h[-cur_ov_h:] = 1.0 - r
-
-            if cur_ov_w > 0:
-                r = get_ramp(cur_ov_w)
-                if x_idx > 0:
-                    w_w[:cur_ov_w] = r
-                if x_end < w:
-                    w_w[-cur_ov_w:] = 1.0 - r
-
-            final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
-
-            valid_d = min(tile_out.shape[2], result.shape[2])
-            tile_out = tile_out[:, :, :valid_d, :, :]
-
-            tile_out.mul_(final_weight)
-
-            result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
-            count[:, :, :, ys:ye, xs:xe] += final_weight
-
-            del tile_out, final_weight, tile_x, w_h, w_w
-            bar.update(1)
-
-    result.div_(count.clamp(min=1e-6))
-
-    if result.device != x.device:
-        result = result.to(x.device).to(x.dtype)
-
-    if x.shape[2] == 1 and sf_t == 1:
-        result = result.squeeze(2)
-
-    return result
-
-def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
-    t = videos.size(temporal_dim)
-
-    if count == 0 and not prepend:
-        if t % 4 == 1:
-            return videos
-        count = ((t - 1) // 4 + 1) * 4 + 1 - t
-
-    if count <= 0:
-        return videos
-
-    def select(start, end):
-        return videos[start:end] if temporal_dim == 0 else videos[:, start:end]
-
-    if count >= t:
-        repeat_count = count - t + 1
-        last = select(-1, None)
-
-        if temporal_dim == 0:
-            repeated = last.repeat(repeat_count, 1, 1, 1)
-            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
-        else:
-            repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
-            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]
-
-        return torch.cat([repeated, reversed_frames, videos] if prepend else
-                         [videos, reversed_frames, repeated], dim=temporal_dim)
-
-    if prepend:
-        reversed_frames = select(1, count+1).flip(temporal_dim)
-    else:
-        reversed_frames = select(-count-1, -1).flip(temporal_dim)
-
-    return torch.cat([reversed_frames, videos] if prepend else
-                     [videos, reversed_frames], dim=temporal_dim)
+from comfy.ldm.seedvr.vae import tiled_vae
 
 def clear_vae_memory(vae_model):
     for module in vae_model.modules():