From 21bc67d7db037d652d3d5fc65087261fc5411b96 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 26 Dec 2025 02:08:59 +0200 Subject: [PATCH] final changes --- comfy/ldm/seedvr/model.py | 70 +++++++++++++++++++----------- comfy/ldm/seedvr/vae.py | 83 +++++++++++++++++++----------------- comfy_extras/nodes_seedvr.py | 61 +++++++++++++++----------- 3 files changed, 125 insertions(+), 89 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 7578c0be5..bd0057332 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -526,22 +526,22 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): max_height = 0 max_width = 0 max_txt_len = 0 - + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal max_height = max(max_height, h) max_width = max(max_width, w) max_txt_len = max(max_txt_len, l) - + # Compute frequencies for actual max dimensions needed # Add small buffer to improve cache hits across similar batches vid_freqs = self.get_axial_freqs( min(max_temporal + 16, 1024), # Cap at 1024, add small buffer - min(max_height + 4, 128), # Cap at 128, add small buffer + min(max_height + 4, 128), # Cap at 128, add small buffer min(max_width + 4, 128) # Cap at 128, add small buffer ) txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024)) - + # Now slice as before vid_freq_list, txt_freq_list = [], [] for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): @@ -615,6 +615,7 @@ class NaMMAttention(nn.Module): rope_type: Optional[str], rope_dim: int, shared_weights: bool, + device, dtype, operations, **kwargs, ): super().__init__() @@ -624,15 +625,16 @@ class NaMMAttention(nn.Module): qkv_dim = inner_dim * 3 self.head_dim = head_dim self.proj_qkv = MMModule( - nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights + operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype ) - self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights) + self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype) self.norm_q = MMModule( qk_norm, normalized_shape=head_dim, eps=qk_norm_eps, elementwise_affine=True, shared_weights=shared_weights, + device=device, dtype=dtype ) self.norm_k = MMModule( qk_norm, @@ -640,6 +642,7 @@ class NaMMAttention(nn.Module): eps=qk_norm_eps, elementwise_affine=True, shared_weights=shared_weights, + device=device, dtype=dtype ) @@ -795,11 +798,12 @@ class MLP(nn.Module): self, dim: int, expand_ratio: int, + device, dtype, operations ): super().__init__() - self.proj_in = nn.Linear(dim, dim * expand_ratio) + self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype) self.act = nn.GELU("tanh") - self.proj_out = nn.Linear(dim * expand_ratio, dim) + self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype) def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: x = self.proj_in(x) @@ -814,13 +818,14 @@ class SwiGLUMLP(nn.Module): dim: int, expand_ratio: int, multiple_of: int = 256, + device=None, dtype=None, operations=None ): super().__init__() hidden_dim = int(2 * dim * expand_ratio / 3) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) - self.proj_out = nn.Linear(hidden_dim, dim, bias=False) - self.proj_in 
= nn.Linear(dim, hidden_dim, bias=False) + self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype) + self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: x = x.to(next(self.proj_in.parameters()).device) @@ -855,11 +860,12 @@ class NaMMSRTransformerBlock(nn.Module): rope_type: str, rope_dim: int, is_last_layer: bool, + device, dtype, operations, **kwargs, ): super().__init__() dim = MMArg(vid_dim, txt_dim) - self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,) + self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) self.attn = NaSwinAttention( vid_dim=vid_dim, @@ -874,17 +880,19 @@ class NaMMSRTransformerBlock(nn.Module): shared_weights=shared_weights, window=kwargs.pop("window", None), window_method=kwargs.pop("window_method", None), + device=device, dtype=dtype, operations=operations ) - self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) + self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) self.mlp = MMModule( get_mlp(mlp_type), dim=dim, expand_ratio=expand_ratio, shared_weights=shared_weights, - vid_only=is_last_layer + vid_only=is_last_layer, + device=device, dtype=dtype, operations=operations ) - self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) self.is_last_layer = is_last_layer def forward( @@ -933,11 +941,12 @@ class PatchOut(nn.Module): out_channels: int, patch_size: Union[int, Tuple[int, int, int]], dim: int, + device, dtype, operations ): super().__init__() t, h, w = _triple(patch_size) self.patch_size = t, h, w - self.proj = nn.Linear(dim, out_channels * t * h * w) + self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) def forward( self, @@ -981,11 +990,12 @@ class PatchIn(nn.Module): in_channels: int, patch_size: Union[int, Tuple[int, int, int]], dim: int, + device, dtype, operations ): super().__init__() t, h, w = _triple(patch_size) self.patch_size = t, h, w - self.proj = nn.Linear(in_channels * t * h * w, dim) + self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) def forward( self, @@ -1033,6 +1043,7 @@ class AdaSingle(nn.Module): emb_dim: int, layers: List[str], modes: List[str] = ["in", "out"], + device = None, dtype = None, ): assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" super().__init__() @@ -1041,12 +1052,12 @@ class AdaSingle(nn.Module): self.layers = layers for l in layers: if "in" in modes: - self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5)) self.register_parameter( f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1) ) if "out" in modes: - self.register_parameter(f"{l}_gate", 
nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5)) def forward( self, @@ -1096,12 +1107,13 @@ class TimeEmbedding(nn.Module): sinusoidal_dim: int, hidden_dim: int, output_dim: int, + device, dtype, operations ): super().__init__() self.sinusoidal_dim = sinusoidal_dim - self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) - self.proj_hid = nn.Linear(hidden_dim, hidden_dim) - self.proj_out = nn.Linear(hidden_dim, output_dim) + self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype) + self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype) self.act = nn.SiLU() def forward( @@ -1199,6 +1211,7 @@ class NaDiT(nn.Module): **kwargs, ): self.dtype = dtype + factory_kwargs = {"device": device, "dtype": dtype} window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] txt_dim = vid_dim emb_dim = vid_dim * 6 @@ -1212,15 +1225,16 @@ class NaDiT(nn.Module): elif len(block_type) != num_layers: raise ValueError("The ``block_type`` list should equal to ``num_layers``.") super().__init__() - self.register_buffer("positive_conditioning", torch.empty((58, 5120))) - self.register_buffer("negative_conditioning", torch.empty((64, 5120))) + self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype)) + self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype)) self.vid_in = NaPatchIn( in_channels=vid_in_channels, patch_size=patch_size, dim=vid_dim, + device=device, dtype=dtype, operations=operations ) self.txt_in = ( - nn.Linear(txt_in_dim, txt_dim) + operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) if txt_in_dim and txt_in_dim != txt_dim else nn.Identity() ) @@ -1228,6 +1242,7 @@ class NaDiT(nn.Module): sinusoidal_dim=256, hidden_dim=max(vid_dim, txt_dim), output_dim=emb_dim, + device=device, dtype=dtype, operations=operations ) if window is None or isinstance(window[0], int): @@ -1268,7 +1283,9 @@ class NaDiT(nn.Module): shared_weights=not ( (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] ), + operations = operations, **kwargs, + **factory_kwargs ) for i in range(num_layers) ] @@ -1277,6 +1294,7 @@ class NaDiT(nn.Module): out_channels=vid_out_channels, patch_size=patch_size, dim=vid_dim, + device=device, dtype=dtype, operations=operations ) self.need_txt_repeat = block_type[0] in [ @@ -1291,12 +1309,14 @@ class NaDiT(nn.Module): normalized_shape=vid_dim, eps=norm_eps, elementwise_affine=True, + device=device, dtype=dtype ) self.vid_out_ada = ada( dim=vid_dim, emb_dim=emb_dim, layers=["out"], modes=["in"], + device=device, dtype=dtype ) self.stop_cfg_index = -1 diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 0c7fa5c5f..9fcea60ad 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -16,6 +16,9 @@ import math from enum import Enum from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND +import comfy.ops +ops = comfy.ops.disable_weight_init + _NORM_LIMIT = float("inf") @@ -89,9 +92,9 @@ class SpatialNorm(nn.Module): zq_channels: int, ): super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, 
stride=1, padding=0) + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: f_size = f.shape[-2:] @@ -164,7 +167,7 @@ class Attention(nn.Module): self.only_cross_attention = only_cross_attention if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None @@ -177,22 +180,22 @@ class Attention(nn.Module): self.norm_k = None self.norm_cross = None - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) else: self.to_k = None self.to_v = None self.added_proj_bias = added_proj_bias if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) else: self.add_q_proj = None self.add_k_proj = None @@ -200,13 +203,13 @@ class Attention(nn.Module): if not self.pre_only: self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) else: self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) else: self.to_add_out = None @@ -325,7 +328,7 @@ def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype - if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)): + if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): if x.ndim == 4: x = rearrange(x, "b c h w -> b h w c") x = norm_layer(x) @@ -336,14 +339,14 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: x = norm_layer(x) x = rearrange(x, "b t h w c -> b c t h w") return x.to(input_dtype) - if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): if x.ndim <= 4: return norm_layer(x).to(input_dtype) if x.ndim == 5: t = x.size(2) x = 
rearrange(x, "b c t h w -> (b t) c h w") memory_occupy = x.numel() * x.element_size() / 1024**3 - if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae + if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups) assert norm_layer.num_groups % num_chunks == 0 num_groups_per_chunk = norm_layer.num_groups // num_chunks @@ -428,7 +431,7 @@ def cache_send_recv(tensor, cache_size, times, memory=None): return recv_buffer -class InflatedCausalConv3d(torch.nn.Conv3d): +class InflatedCausalConv3d(ops.Conv3d): def __init__( self, *args, @@ -677,17 +680,16 @@ class Upsample3D(nn.Module): if use_conv_transpose: if kernel_size is None: kernel_size = 4 - self.conv = nn.ConvTranspose2d( + self.conv = ops.ConvTranspose2d( channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias ) elif use_conv: if kernel_size is None: kernel_size = 3 - self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) conv = self.conv if self.name == "conv" else self.Conv2d_0 - assert type(conv) is not nn.ConvTranspose2d # Note: lora_layer is not passed into constructor in the original implementation. # So we make a simplification. conv = InflatedCausalConv3d( @@ -708,7 +710,7 @@ class Upsample3D(nn.Module): # [Override] MAGViT v2 implementation if not self.interpolate: upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio - self.upscale_conv = nn.Conv3d( + self.upscale_conv = ops.Conv3d( self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 ) identity = ( @@ -892,13 +894,13 @@ class ResnetBlock3D(nn.Module): self.skip_time_act = skip_time_act self.nonlinearity = nn.SiLU() if temb_channels is not None: - self.time_emb_proj = nn.Linear(temb_channels, out_channels) + self.time_emb_proj = ops.Linear(temb_channels, out_channels) else: self.time_emb_proj = None - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) if groups_out is None: groups_out = groups - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.use_in_shortcut = self.in_channels != out_channels self.dropout = torch.nn.Dropout(dropout) self.conv1 = InflatedCausalConv3d( @@ -1342,7 +1344,7 @@ class Encoder3D(nn.Module): self.conv_extra_cond.append( zero_module( - nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) + ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) ) if self.extra_cond_dim is not None and self.extra_cond_dim > 0 else None @@ -1364,7 +1366,7 @@ class Encoder3D(nn.Module): ) # out - self.conv_norm_out = nn.GroupNorm( + self.conv_norm_out = ops.GroupNorm( num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() @@ -1512,7 +1514,7 @@ class Decoder3D(nn.Module): if norm_type == "spatial": self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: - self.conv_norm_out = nn.GroupNorm( + self.conv_norm_out = ops.GroupNorm( 
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() @@ -1553,9 +1555,9 @@ def wavelet_blur(image: Tensor, radius): max_safe_radius = max(1, min(image.shape[-2:]) // 8) if radius > max_safe_radius: radius = max_safe_radius - + num_channels = image.shape[1] - + kernel_vals = [ [0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], @@ -1563,21 +1565,21 @@ def wavelet_blur(image: Tensor, radius): ] kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - + image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) - + return output def wavelet_decomposition(image: Tensor, levels: int = 5): high_freq = torch.zeros_like(image) - + for i in range(levels): radius = 2 ** i low_freq = wavelet_blur(image, radius) high_freq.add_(image).sub_(low_freq) image = low_freq - + return high_freq, low_freq def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: @@ -1587,19 +1589,19 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: if len(content_feat.shape) >= 3: # safe_interpolate_operation handles FP16 conversion automatically style_feat = safe_interpolate_operation( - style_feat, + style_feat, size=content_feat.shape[-2:], - mode='bilinear', + mode='bilinear', align_corners=False ) - + # Decompose both features into frequency components content_high_freq, content_low_freq = wavelet_decomposition(content_feat) del content_low_freq # Free memory immediately - - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) del style_high_freq # Free memory immediately - + if content_high_freq.shape != style_low_freq.shape: style_low_freq = safe_interpolate_operation( style_low_freq, @@ -1607,9 +1609,9 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: mode='bilinear', align_corners=False ) - + content_high_freq.add_(style_low_freq) - + return content_high_freq.clamp_(-1.0, 1.0) class VideoAutoencoderKL(nn.Module): @@ -1894,6 +1896,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): x = rearrange(x, "b c t h w -> (b t) c h w") + input = input.to(x.device) x = wavelet_reconstruction(x, input) x = x.unsqueeze(0) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e6ccd44c1..bd0c6037a 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -24,7 +24,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora x = x.unsqueeze(2) b, c, d, h, w = x.shape - + sf_s = getattr(vae_model, "spatial_downsample_factor", 8) sf_t = getattr(vae_model, "temporal_downsample_factor", 4) @@ -39,7 +39,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ti_w = max(1, tile_size[1] // sf_s) ov_h = max(0, tile_overlap[0] // sf_s) ov_w = max(0, tile_overlap[1] // sf_s) - + target_d = d * sf_t target_h = h * sf_s target_w = w * sf_s @@ -47,15 +47,14 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora stride_h = max(1, ti_h - ov_h) stride_w = max(1, ti_w - ov_w) - storage_device = torch.device("cpu") - + storage_device = vae_model.device result = None count = None def run_temporal_chunks(spatial_tile): chunk_results = [] t_dim_size = spatial_tile.shape[2] - + if encode: input_chunk = temporal_size else: @@ -63,18 
+62,18 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora for i in range(0, t_dim_size, input_chunk): t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :] - + if encode: - out = vae_model.slicing_encode(t_chunk) + out = vae_model.encode(t_chunk) else: - out = vae_model.slicing_decode(t_chunk) - + out = vae_model.decode_(t_chunk) + if isinstance(out, (tuple, list)): out = out[0] - + if out.ndim == 4: out = out.unsqueeze(2) - - chunk_results.append(out.to(storage_device)) - + + chunk_results.append(out.to(storage_device)) + return torch.cat(chunk_results, dim=2) ramp_cache = {} @@ -89,7 +88,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora for y_idx in range(0, h, stride_h): y_end = min(y_idx + ti_h, h) - + for x_idx in range(0, w, stride_w): x_end = min(x_idx + ti_w, w) @@ -131,9 +130,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora valid_d = min(tile_out.shape[2], result.shape[2]) tile_out = tile_out[:, :, :valid_d, :, :] - + tile_out.mul_(final_weight) - + result[:, :, :valid_d, ys:ye, xs:xe] += tile_out count[:, :, :, ys:ye, xs:xe] += final_weight @@ -141,7 +140,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora bar.update(1) result.div_(count.clamp(min=1e-6)) - + if result.device != x.device: result = result.to(x.device).to(x.dtype) @@ -150,6 +149,18 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora return result +def clear_vae_memory(vae_model): + for module in vae_model.modules(): + if hasattr(module, "memory"): + module.memory = None + if hasattr(vae_model, "original_image_video"): + del vae_model.original_image_video + + if hasattr(vae_model, "tiled_args"): + del vae_model.tiled_args + gc.collect() + torch.cuda.empty_cache() + def expand_dims(tensor, ndim): shape = tensor.shape + (1,) * (ndim - tensor.ndim) return tensor.reshape(shape) @@ -261,9 +272,9 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Vae.Input("vae"), io.Int.Input("resolution_height", default = 1280, min = 120), # // io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value - io.Int.Input("spatial_tile_size", default = 512, min = -1), - io.Int.Input("temporal_tile_size", default = 8, min = -1), - io.Int.Input("spatial_overlap", default = 64, min = -1), + io.Int.Input("spatial_tile_size", default = 512, min = 1), + io.Int.Input("temporal_tile_size", default = 8, min = 1), + io.Int.Input("spatial_overlap", default = 64, min = 1), io.Boolean.Input("enable_tiling", default=False) ], outputs = [ @@ -305,7 +316,6 @@ class SeedVR2InputProcessing(io.ComfyNode): images = rearrange(images, "b t c h w -> b c t h w") images = images.to(device) vae_model = vae_model.to(device) - vae_model.original_image_video = images args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), "temporal_size":temporal_tile_size} @@ -314,11 +324,14 @@ class SeedVR2InputProcessing(io.ComfyNode): else: latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0] + clear_vae_memory(vae_model) + #images = images.to(offload_device) + #vae_model = vae_model.to(offload_device) + + vae_model.img_dims = [o_h, o_w] args["enable_tiling"] = enable_tiling vae_model.tiled_args = args - - vae_model = vae_model.to(offload_device) - vae_model.img_dims = [o_h, o_w] + vae_model.original_image_video = images latent = latent.unsqueeze(2) if latent.ndim == 4 else latent latent = rearrange(latent, "b c ... -> b ... c")
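
For context on the tiled_vae() hunks above: they follow the usual overlap-blending pattern, where each spatial tile is weighted by a ramp, accumulated into a running result buffer alongside a matching weight (count) buffer, and the result is normalized by the accumulated weight at the end (result.div_(count.clamp(min=1e-6))). The standalone sketch below illustrates only that pattern, reduced to a single 4-D image tensor for clarity, whereas the node itself works on 5-D video latents with an extra temporal chunking loop. The helper names ramp_weight, blend_tiles and process_tile are hypothetical and are not identifiers from this patch or repository.

import torch

def ramp_weight(tile_h, tile_w, overlap, device, dtype):
    # Linear ramp that fades a tile toward its borders so overlapping
    # regions cross-fade instead of producing visible seams.
    wy = torch.ones(tile_h, device=device, dtype=dtype)
    wx = torch.ones(tile_w, device=device, dtype=dtype)
    if overlap > 0:
        # Strictly positive fade values in (0, 1) so every pixel keeps
        # non-zero total weight.
        fade = torch.linspace(0.0, 1.0, overlap + 2, device=device, dtype=dtype)[1:-1]
        wy[:overlap] = fade
        wy[-overlap:] = fade.flip(0)
        wx[:overlap] = fade
        wx[-overlap:] = fade.flip(0)
    return wy[:, None] * wx[None, :]

def blend_tiles(x, process_tile, tile=64, overlap=16):
    # x: (B, C, H, W); process_tile: callable applied to each tile.
    b, c, h, w = x.shape
    out = torch.zeros_like(x)
    count = torch.zeros(1, 1, h, w, device=x.device, dtype=x.dtype)
    stride = max(1, tile - overlap)
    for y in range(0, h, stride):
        ye = min(y + tile, h)
        for xs in range(0, w, stride):
            xe = min(xs + tile, w)
            t = process_tile(x[:, :, y:ye, xs:xe])
            # Clamp the overlap so the ramp never exceeds half the tile.
            ov = min(overlap, (ye - y) // 2, (xe - xs) // 2)
            wgt = ramp_weight(ye - y, xe - xs, ov, x.device, x.dtype)
            out[:, :, y:ye, xs:xe] += t * wgt
            count[:, :, y:ye, xs:xe] += wgt
    # Normalize by the accumulated per-pixel weight, mirroring
    # result.div_(count.clamp(min=1e-6)) in tiled_vae().
    return out / count.clamp(min=1e-6)

if __name__ == "__main__":
    # With an identity "model", blending must reconstruct the input exactly.
    img = torch.rand(1, 3, 200, 200)
    rec = blend_tiles(img, lambda t: t, tile=64, overlap=16)
    assert torch.allclose(rec, img, atol=1e-5)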