From 1afc2ed8e60b5206c2475825b4191309a6ad5234 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Wed, 24 Dec 2025 02:23:57 +0200
Subject: [PATCH] fixed the speed issue

---
 comfy/ldm/seedvr/vae.py      | 488 ++++++++++++++++++++++++++++-------
 comfy_extras/nodes_seedvr.py | 157 +++++------
 2 files changed, 485 insertions(+), 160 deletions(-)

diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index da2bb2c2f..0c7fa5c5f 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -5,12 +5,57 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
+from contextlib import contextmanager
 
 import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
 from comfy_extras.nodes_seedvr import tiled_vae
+import math
+from enum import Enum
+from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
+
+_NORM_LIMIT = float("inf")
+
+
+def get_norm_limit():
+    return _NORM_LIMIT
+
+
+def set_norm_limit(value: Optional[float] = None):
+    global _NORM_LIMIT
+    if value is None:
+        value = float("inf")
+    _NORM_LIMIT = value
+
+@contextmanager
+def ignore_padding(model):
+    orig_padding = model.padding
+    model.padding = (0, 0, 0)
+    try:
+        yield
+    finally:
+        model.padding = orig_padding
+
+class MemoryState(Enum):
+    DISABLED = 0
+    INITIALIZING = 1
+    ACTIVE = 2
+    UNSET = 3
+
+def get_cache_size(conv_module, input_len, pad_len, dim=0):
+    # Number of trailing input elements that must be carried over to the next
+    # split so the conv output is seamless across split boundaries.
+    dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
+    output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
+    remain_len = (
+        input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size)
+    )
+    overlap_len = dilated_kernel_size - conv_module.stride[dim]
+    cache_len = overlap_len + remain_len  # >= 0
+
+    assert output_len > 0
+    return cache_len
+
 class DiagonalGaussianDistribution(object):
     def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
         self.parameters = parameters
@@ -34,6 +79,9 @@ class DiagonalGaussianDistribution(object):
         x = self.mean + self.std * sample
         return x
 
+    def mode(self):
+        return self.mean
+
 class SpatialNorm(nn.Module):
     def __init__(
         self,
@@ -366,41 +414,233 @@ def extend_head(tensor, times: int = 2, memory = None):
         tile_repeat[2] = times
         return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
 
-class InflatedCausalConv3d(nn.Conv3d):
+def cache_send_recv(tensor, cache_size, times, memory=None):
+    # Single GPU inference - simplified cache handling.
+    # cache_size is unused on this path but kept for call-site compatibility.
+    recv_buffer = None
+
+    # Reuse the stored memory if present, otherwise replicate the head frame.
+    if memory is not None:
+        recv_buffer = memory.to(tensor[0])
+    elif times > 0:
+        tile_repeat = [1] * tensor[0].ndim
+        tile_repeat[2] = times
+        recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat)
+
+    return recv_buffer
+
+class InflatedCausalConv3d(torch.nn.Conv3d):
     def __init__(
         self,
         *args,
         inflation_mode,
+        memory_device = "same",
         **kwargs,
     ):
         self.inflation_mode = inflation_mode
         self.memory = None
         super().__init__(*args, **kwargs)
         self.temporal_padding = self.padding[0]
+        self.memory_device = memory_device
         self.padding = (0, *self.padding[1:])
         self.memory_limit = float("inf")
 
+    def set_memory_limit(self, value: float):
+        self.memory_limit = value
+
+    def set_memory_device(self, memory_device):
+        self.memory_device = memory_device
+
+    def _conv_forward(self, input, weight, bias, *args, **kwargs):
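+        # Workaround for an NVIDIA cuDNN memory bug with fp16/bf16 weights:
+        # call the cuDNN convolution directly and fall back to the default
+        # PyTorch path if the direct call raises a RuntimeError.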
+        if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and
+                weight.dtype in (torch.float16, torch.bfloat16) and
+                hasattr(torch.backends.cudnn, 'is_available') and
+                torch.backends.cudnn.is_available() and
+                getattr(torch.backends.cudnn, 'enabled', True)):
+            try:
+                out = torch.cudnn_convolution(
+                    input, weight, self.padding, self.stride, self.dilation, self.groups,
+                    benchmark=False, deterministic=False, allow_tf32=True
+                )
+                if bias is not None:
+                    out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
+                return out
+            except RuntimeError:
+                pass
+
+        return super()._conv_forward(input, weight, bias, *args, **kwargs)
+
+    def memory_limit_conv(
+        self,
+        x,
+        *,
+        split_dim=3,
+        padding=(0, 0, 0, 0, 0, 0),
+        prev_cache=None,
+    ):
+        # Fast path: no memory limit set, run the convolution in one shot.
+        if math.isinf(self.memory_limit):
+            if prev_cache is not None:
+                x = torch.cat([prev_cache, x], dim=split_dim - 1)
+            return super().forward(x)
+
+        # Compute tensor shape after concat & padding.
+        shape = torch.tensor(x.size())
+        if prev_cache is not None:
+            shape[split_dim - 1] += prev_cache.size(split_dim - 1)
+        shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0)
+        memory_occupy = shape.prod() * x.element_size() / 1024**3  # GiB
+        if memory_occupy < self.memory_limit or split_dim == x.ndim:
+            x_concat = x
+            if prev_cache is not None:
+                x_concat = torch.cat([prev_cache, x], dim=split_dim - 1)
+
+            def pad_and_forward():
+                padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
+                with ignore_padding(self):
+                    return torch.nn.Conv3d.forward(self, padded)
+
+            return pad_and_forward()
+
+        # Otherwise split along split_dim and recurse, carrying a cache of the
+        # overlap region between neighbouring splits.
+        num_splits = math.ceil(memory_occupy / self.memory_limit)
+        size_per_split = x.size(split_dim) // num_splits
+        split_sizes = [size_per_split] * (num_splits - 1)
+        split_sizes += [x.size(split_dim) - sum(split_sizes)]
+
+        x = list(x.split(split_sizes, dim=split_dim))
+        if prev_cache is not None:
+            prev_cache = list(prev_cache.split(split_sizes, dim=split_dim))
+        cache = None
+        for idx in range(len(x)):
+            if prev_cache is not None:
+                x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1)
+
+            lpad_dim = (x[idx].ndim - split_dim - 1) * 2
+            rpad_dim = lpad_dim + 1
+            padding = list(padding)
+            padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0
+            padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0
+            pad_len = padding[lpad_dim] + padding[rpad_dim]
+            padding = tuple(padding)
+
+            next_cache = None
+            cache_len = cache.size(split_dim) if cache is not None else 0
+            next_cache_size = get_cache_size(
+                conv_module=self,
+                input_len=x[idx].size(split_dim) + cache_len,
+                pad_len=pad_len,
+                dim=split_dim - 2,
+            )
+            if next_cache_size != 0:
+                assert next_cache_size <= x[idx].size(split_dim)
+                next_cache = (
+                    x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim)
+                )
+
+            x[idx] = self.memory_limit_conv(
+                x[idx],
+                split_dim=split_dim + 1,
+                padding=padding,
+                prev_cache=cache
+            )
+
+            cache = next_cache
+
+        output = torch.cat(x, dim=split_dim)
+        return output
+
     def forward(
         self,
         input,
-    ):
-        input = extend_head(input, times=self.temporal_padding * 2)
+        memory_state: MemoryState = MemoryState.UNSET
+    ) -> Tensor:
+        assert memory_state != MemoryState.UNSET
+        if memory_state != MemoryState.ACTIVE:
+            self.memory = None
+        if (
+            math.isinf(self.memory_limit)
+            and torch.is_tensor(input)
+        ):
+            return self.basic_forward(input, memory_state)
+        return self.slicing_forward(input, memory_state)
+
+    def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET):
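+        # mem_size is stride - kernel_size (<= 0), so input[:, :, mem_size:]
+        # below selects the trailing frames that overlap with the next slice.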
+        mem_size = self.stride[0] - self.kernel_size[0]
+        if (self.memory is not None) and (memory_state == MemoryState.ACTIVE):
+            input = extend_head(input, memory=self.memory, times=-1)
+        else:
+            input = extend_head(input, times=self.temporal_padding * 2)
+        memory = (
+            input[:, :, mem_size:].detach()
+            if (mem_size != 0 and memory_state != MemoryState.DISABLED)
+            else None
+        )
+        if (
+            memory_state != MemoryState.DISABLED
+            and not self.training
+            and (self.memory_device is not None)
+        ):
+            self.memory = memory
+            if self.memory_device == "cpu" and self.memory is not None:
+                self.memory = self.memory.to("cpu")
         return super().forward(input)
 
-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
+    def slicing_forward(
+        self,
+        input,
+        memory_state: MemoryState = MemoryState.UNSET,
+    ) -> Tensor:
+        squeeze_out = False
+        if torch.is_tensor(input):
+            input = [input]
+            squeeze_out = True
 
-        super()._load_from_state_dict(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
+        cache_size = self.kernel_size[0] - self.stride[0]
+        cache = cache_send_recv(
+            input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2
         )
+        # Single GPU inference - simplified memory management
+        if (
+            memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE]  # use_slicing
+            and not self.training
+            and (self.memory_device is not None)
+            and cache_size != 0
+        ):
+            if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
+                input[0] = torch.cat([cache, input[0]], dim=2)
+                cache = None
+            if cache_size <= input[-1].size(2):
+                self.memory = input[-1][:, :, -cache_size:].detach().contiguous()
+                if self.memory_device == "cpu" and self.memory is not None:
+                    self.memory = self.memory.to("cpu")
+
+        padding = tuple(x for x in reversed(self.padding) for _ in range(2))
+        for i in range(len(input)):
+            # Prepare cache for next input slice.
+            next_cache = None
+            cache_size = 0
+            if i < len(input) - 1:
+                cache_len = cache.size(2) if cache is not None else 0
+                cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0)
+            if cache_size != 0:
+                if cache_size > input[i].size(2) and cache is not None:
+                    input[i] = torch.cat([cache, input[i]], dim=2)
+                    cache = None
+                assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}"
+                next_cache = input[i][:, :, -cache_size:]
+
+            # Conv forward for this input slice.
+            input[i] = self.memory_limit_conv(
+                input[i],
+                padding=padding,
+                prev_cache=cache
+            )
+
+            # Update cache.
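+            # Hand the retained tail frames to the next slice so temporal
+            # context is preserved across slice boundaries.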
+ cache = next_cache + + return input[0] if squeeze_out else input + def remove_head(tensor: Tensor, times: int = 1) -> Tensor: if times == 0: return tensor @@ -488,6 +728,7 @@ class Upsample3D(nn.Module): def forward( self, hidden_states: torch.FloatTensor, + memory_state=None, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels @@ -517,7 +758,7 @@ class Upsample3D(nn.Module): z=self.temporal_ratio, ) - if self.temporal_up: + if self.temporal_up and memory_state != MemoryState.ACTIVE: hidden_states[0] = remove_head(hidden_states[0]) if not self.slicing: @@ -525,9 +766,9 @@ class Upsample3D(nn.Module): if self.use_conv: if self.name == "conv": - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, memory_state=memory_state) else: - hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) if not self.slicing: return hidden_states @@ -594,6 +835,7 @@ class Downsample3D(nn.Module): def forward( self, hidden_states: torch.FloatTensor, + memory_state = None, **kwargs, ) -> torch.FloatTensor: @@ -609,7 +851,7 @@ class Downsample3D(nn.Module): assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, memory_state=memory_state) return hidden_states @@ -707,7 +949,7 @@ class ResnetBlock3D(nn.Module): ) def forward( - self, input_tensor, temb, **kwargs + self, input_tensor, temb, memory_state = None, **kwargs ): hidden_states = input_tensor @@ -719,13 +961,13 @@ class ResnetBlock3D(nn.Module): if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) + input_tensor = self.upsample(input_tensor, memory_state=memory_state) + hidden_states = self.upsample(hidden_states, memory_state=memory_state) elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) + input_tensor = self.downsample(input_tensor, memory_state=memory_state) + hidden_states = self.downsample(hidden_states, memory_state=memory_state) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, memory_state=memory_state) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -740,10 +982,10 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, memory_state=memory_state) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor @@ -819,15 +1061,16 @@ class DownEncoderBlock3D(nn.Module): def forward( self, hidden_states: torch.FloatTensor, + memory_state = None, **kwargs, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, memory_state=memory_state) return hidden_states @@ -907,14 +1150,15 @@ 
class UpDecoderBlock3D(nn.Module): self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, + memory_state=None ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) hidden_states = temporal(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, memory_state=memory_state) return hidden_states @@ -1008,9 +1252,9 @@ class UNetMidBlock3D(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, memory_state=None): video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") @@ -1018,7 +1262,7 @@ class UNetMidBlock3D(nn.Module): hidden_states = rearrange( hidden_states, "(b f) c h w -> b c f h w", f=video_length ) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, memory_state=memory_state) return hidden_states @@ -1136,10 +1380,11 @@ class Encoder3D(nn.Module): self, sample: torch.FloatTensor, extra_cond=None, + memory_state = None ) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample) + sample = self.conv_in(sample, memory_state = memory_state) if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -1164,17 +1409,17 @@ class Encoder3D(nn.Module): # down # [Override] add extra block and extra cond for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = down_block(sample) + sample = down_block(sample, memory_state=memory_state) if extra_block is not None: sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) # middle - sample = self.mid_block(sample) + sample = self.mid_block(sample, memory_state=memory_state) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample) + sample = self.conv_out(sample, memory_state = memory_state) return sample @@ -1282,74 +1527,90 @@ class Decoder3D(nn.Module): self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None, + memory_state = None, ) -> torch.FloatTensor: sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample) + sample = self.conv_in(sample, memory_state=memory_state) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype # middle - sample = self.mid_block(sample, latent_embeds) + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds) + sample = up_block(sample, latent_embeds, memory_state=memory_state) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample) + sample = self.conv_out(sample, memory_state=memory_state) return sample -def wavelet_blur(image: 
Tensor, radius: int):
-    """
-    Apply wavelet blur to the input tensor.
-    """
-    # input shape: (1, 3, H, W)
-    # convolution kernel
+def wavelet_blur(image: Tensor, radius: int):
+    # Clamp the radius so the dilated kernel still fits inside the image.
+    max_safe_radius = max(1, min(image.shape[-2:]) // 8)
+    if radius > max_safe_radius:
+        radius = max_safe_radius
+
+    num_channels = image.shape[1]
+
     kernel_vals = [
         [0.0625, 0.125, 0.0625],
         [0.125, 0.25, 0.125],
         [0.0625, 0.125, 0.0625],
     ]
     kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
-    # add channel dimensions to the kernel to make it a 4D tensor
-    kernel = kernel[None, None]
-    # repeat the kernel across all input channels
-    kernel = kernel.repeat(3, 1, 1, 1)
-    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
-    # apply convolution
-    output = F.conv2d(image, kernel, groups=3, dilation=radius)
+    # One 3x3 low-pass kernel per input channel (depthwise convolution).
+    kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
+
+    image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
+    output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
+
     return output
 
-def wavelet_decomposition(image: Tensor, levels=5):
-    """
-    Apply wavelet decomposition to the input tensor.
-    This function only returns the low frequency & the high frequency.
-    """
+def wavelet_decomposition(image: Tensor, levels: int = 5):
     high_freq = torch.zeros_like(image)
+
     for i in range(levels):
         radius = 2 ** i
         low_freq = wavelet_blur(image, radius)
-        high_freq += (image - low_freq)
+        # In-place accumulation of the band: high_freq += image - low_freq.
+        high_freq.add_(image).sub_(low_freq)
         image = low_freq
 
     return high_freq, low_freq
 
-def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
-    """
-    Apply wavelet decomposition, so that the content will have the same color as the style.
-    """
-    # calculate the wavelet decomposition of the content feature
+def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
+    if content_feat.shape != style_feat.shape:
+        # Resize style to match content spatial dimensions
+        if len(content_feat.shape) >= 3:
+            # safe_interpolate_operation handles FP16 conversion automatically
+            style_feat = safe_interpolate_operation(
+                style_feat,
+                size=content_feat.shape[-2:],
+                mode='bilinear',
+                align_corners=False
+            )
+
+    # Decompose both features into frequency components
     content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
-    del content_low_freq
-    # calculate the wavelet decomposition of the style feature
-    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
-    del style_high_freq
-    # reconstruct the content feature with the style's high frequency
-    return content_high_freq + style_low_freq
+    del content_low_freq  # Free memory immediately
+
+    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+    del style_high_freq  # Free memory immediately
+
+    if content_high_freq.shape != style_low_freq.shape:
+        style_low_freq = safe_interpolate_operation(
+            style_low_freq,
+            size=content_high_freq.shape[-2:],
+            mode='bilinear',
+            align_corners=False
+        )
+
+    # Keep the content's detail (high frequency), take the style's color
+    # (low frequency), then clamp to the expected output range.
+    content_high_freq.add_(style_low_freq)
+
+    return content_high_freq.clamp_(-1.0, 1.0)
 
 class VideoAutoencoderKL(nn.Module):
     def __init__(
@@ -1368,9 +1629,12 @@ class VideoAutoencoderKL(nn.Module):
         time_receptive_field: _receptive_field_t = "full",
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
+        slicing_sample_min_size = 4,
         *args,
         **kwargs,
     ):
+        self.slicing_sample_min_size = slicing_sample_min_size
+        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
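+        # Latent slices are temporally downscaled by 2**temporal_scale_num
+        # relative to sample slices.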
         extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
         block_out_channels = (128, 256, 512, 512)
         down_block_types = ("DownEncoderBlock3D",) * 4
@@ -1438,9 +1702,11 @@ class VideoAutoencoderKL(nn.Module):
         self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
         self.decoder.mid_block.attentions = torch.nn.ModuleList([None])
 
+        self.use_slicing = True
+
     def encode(self, x: torch.FloatTensor, return_dict: bool = True):
         h = self.slicing_encode(x)
-        posterior = DiagonalGaussianDistribution(h).sample()
+        posterior = DiagonalGaussianDistribution(h).mode()
 
         if not return_dict:
             return (posterior,)
@@ -1458,30 +1724,72 @@ class VideoAutoencoderKL(nn.Module):
 
         return decoded
 
     def _encode(
-        self, x: torch.Tensor
+        self, x, memory_state
     ) -> torch.Tensor:
         _x = x.to(self.device)
-        h = self.encoder(_x,)
+        h = self.encoder(_x, memory_state=memory_state)
         if self.quant_conv is not None:
-            output = self.quant_conv(h)
+            output = self.quant_conv(h, memory_state=memory_state)
         else:
             output = h
         return output.to(x.device)
 
     def _decode(
-        self, z: torch.Tensor
+        self, z, memory_state
     ) -> torch.Tensor:
-        latent = z.to(self.device)
+        _z = z.to(self.device)
+
         if self.post_quant_conv is not None:
-            latent = self.post_quant_conv(latent)
-        output = self.decoder(latent)
+            _z = self.post_quant_conv(_z, memory_state=memory_state)
+
+        output = self.decoder(_z, memory_state=memory_state)
         return output.to(z.device)
 
     def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
-        return self._encode(x)
+        sp_size = 1
+        if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
+            # The first slice keeps the leading frame and initializes the conv
+            # memory; subsequent slices reuse it with MemoryState.ACTIVE.
+            x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
+            encoded_slices = [
+                self._encode(
+                    torch.cat((x[:, :, :1], x_slices[0]), dim=2),
+                    memory_state=MemoryState.INITIALIZING,
+                )
+            ]
+            for x_idx in range(1, len(x_slices)):
+                encoded_slices.append(
+                    self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
+                )
+            out = torch.cat(encoded_slices, dim=2)
+            # Drop any conv memory left over from the last slice.
+            modules_with_memory = [m for m in self.modules()
+                                   if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
+            for m in modules_with_memory:
+                m.memory = None
+            return out
+        else:
+            return self._encode(x, memory_state=MemoryState.DISABLED)
 
     def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
-        return self._decode(z)
+        sp_size = 1
+        if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
+            z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
+            decoded_slices = [
+                self._decode(
+                    torch.cat((z[:, :, :1], z_slices[0]), dim=2),
+                    memory_state=MemoryState.INITIALIZING
+                )
+            ]
+            for z_idx in range(1, len(z_slices)):
+                decoded_slices.append(
+                    self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
+                )
+            out = torch.cat(decoded_slices, dim=2)
+            # Drop any conv memory left over from the last slice.
+            modules_with_memory = [m for m in self.modules()
+                                   if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
+            for m in modules_with_memory:
+                m.memory = None
+            return out
+        else:
+            return self._decode(z, memory_state=MemoryState.DISABLED)
 
     def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         raise NotImplementedError
@@ -1531,6 +1839,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         self.freeze_encoder = freeze_encoder
         self.original_image_video = None
         super().__init__(*args, **kwargs)
+        self.set_memory_limit(0.5, 0.5)
 
     def forward(self, x: torch.FloatTensor):
         with torch.no_grad() if self.freeze_encoder else nullcontext():
@@ -1567,8 +1876,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
 
         target_device = comfy.model_management.get_torch_device()
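+        # Keep the decoder on the compute device for the whole decode pass.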
self.decoder.to(target_device) - x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2) - #x = super().decode(latent).squeeze(2) + if self.tiled_args.get("enable_tiling", None) is not None: + self.enable_tiling = self.tiled_args.pop("enable_tiling", False) + + if self.enable_tiling: + x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2) + else: + x = super().decode_(latent).squeeze(2) input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w") if x.ndim == 4: @@ -1581,6 +1895,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): x = rearrange(x, "b c t h w -> (b t) c h w") x = wavelet_reconstruction(x, input) + x = x.unsqueeze(0) o_h, o_w = self.img_dims x = x[..., :o_h, :o_w] @@ -1595,8 +1910,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): - # TODO - #set_norm_limit(norm_max_mem) + set_norm_limit(norm_max_mem) for m in self.modules(): if isinstance(m, InflatedCausalConv3d): m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index ce5437517..8380e4feb 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -14,25 +14,23 @@ from torchvision.transforms import Lambda, Normalize from torchvision.transforms.functional import InterpolationMode @torch.inference_mode() -def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True): +def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True): gc.collect() torch.cuda.empty_cache() + x = x.to(next(vae_model.parameters()).dtype) if x.ndim != 5: x = x.unsqueeze(2) b, c, d, h, w = x.shape - + sf_s = getattr(vae_model, "spatial_downsample_factor", 8) sf_t = getattr(vae_model, "temporal_downsample_factor", 4) if encode: ti_h, ti_w = tile_size ov_h, ov_w = tile_overlap - ti_t = temporal_size - ov_t = temporal_overlap - target_d = (d + sf_t - 1) // sf_t target_h = (h + sf_s - 1) // sf_s target_w = (w + sf_s - 1) // sf_s @@ -41,21 +39,44 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ti_w = max(1, tile_size[1] // sf_s) ov_h = max(0, tile_overlap[0] // sf_s) ov_w = max(0, tile_overlap[1] // sf_s) - ti_t = max(1, temporal_size // sf_t) - ov_t = max(0, temporal_overlap // sf_t) - + target_d = d * sf_t target_h = h * sf_s target_w = w * sf_s - stride_t = max(1, ti_t - ov_t) stride_h = max(1, ti_h - ov_h) stride_w = max(1, ti_w - ov_w) storage_device = torch.device("cpu") + result = None count = None + def run_temporal_chunks(spatial_tile): + chunk_results = [] + t_dim_size = spatial_tile.shape[2] + + if encode: + input_chunk = temporal_size + else: + input_chunk = max(1, temporal_size // sf_t) + + for i in range(0, t_dim_size, input_chunk): + t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :] + + if encode: + out = vae_model.encode(t_chunk) + else: + out = vae_model.decode_(t_chunk) + + if isinstance(out, (tuple, list)): out = out[0] + + if out.ndim == 4: out = out.unsqueeze(2) + + chunk_results.append(out.to(storage_device)) + + return torch.cat(chunk_results, dim=2) + ramp_cache = {} def get_ramp(steps): if steps not in ramp_cache: @@ -63,79 +84,64 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) return ramp_cache[steps] - bar = ProgressBar(d 
// stride_t) - for t_idx in range(0, d, stride_t): - t_end = min(t_idx + ti_t, d) + total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w)) + bar = ProgressBar(total_tiles) - for y_idx in range(0, h, stride_h): - y_end = min(y_idx + ti_h, h) + for y_idx in range(0, h, stride_h): + y_end = min(y_idx + ti_h, h) + + for x_idx in range(0, w, stride_w): + x_end = min(x_idx + ti_w, w) - for x_idx in range(0, w, stride_w): - x_end = min(x_idx + ti_w, w) + tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] - tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end] + # Run VAE + tile_out = run_temporal_chunks(tile_x) - if encode: - tile_out = vae_model.encode(tile_x)[0] - else: - tile_out = vae_model.decode_(tile_x) + if result is None: + b_out, c_out = tile_out.shape[0], tile_out.shape[1] + result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) - if tile_out.ndim == 4: - tile_out = tile_out.unsqueeze(2) + if encode: + ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] + xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2)) + else: + ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] + xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2)) - tile_out = tile_out.to(storage_device).float() + w_h = torch.ones((tile_out.shape[3],), device=storage_device) + w_w = torch.ones((tile_out.shape[4],), device=storage_device) - if result is None: - b_out, c_out = tile_out.shape[0], tile_out.shape[1] - result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) - count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + if cur_ov_h > 0: + r = get_ramp(cur_ov_h) + if y_idx > 0: w_h[:cur_ov_h] = r + if y_end < h: w_h[-cur_ov_h:] = 1.0 - r - if encode: - ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2] - ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] - xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] + if cur_ov_w > 0: + r = get_ramp(cur_ov_w) + if x_idx > 0: w_w[:cur_ov_w] = r + if x_end < w: w_w[-cur_ov_w:] = 1.0 - r - cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2)) - cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2)) - else: - ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2] - ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] - xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] + final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) - cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2)) - cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2)) + valid_d = min(tile_out.shape[2], result.shape[2]) + tile_out = tile_out[:, :, :valid_d, :, :] + + tile_out.mul_(final_weight) + + result[:, :, :valid_d, ys:ye, xs:xe] += tile_out + count[:, :, :, ys:ye, xs:xe] += final_weight - w_t = torch.ones((tile_out.shape[2],), device=storage_device) - w_h = torch.ones((tile_out.shape[3],), device=storage_device) - w_w = torch.ones((tile_out.shape[4],), device=storage_device) + del tile_out, final_weight, tile_x, w_h, w_w + 
bar.update(1)
 
-                    if cur_ov_t > 0:
-                        r = get_ramp(cur_ov_t)
-                        if t_idx > 0: w_t[:cur_ov_t] = r
-                        if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
-
-                    if cur_ov_h > 0:
-                        r = get_ramp(cur_ov_h)
-                        if y_idx > 0: w_h[:cur_ov_h] = r
-                        if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
-
-                    if cur_ov_w > 0:
-                        r = get_ramp(cur_ov_w)
-                        if x_idx > 0: w_w[:cur_ov_w] = r
-                        if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
-
-                    final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
-
-                    tile_out.mul_(final_weight)
-                    result[:, :, ts:te, ys:ye, xs:xe] += tile_out
-                    count[:, :, ts:te, ys:ye, xs:xe] += final_weight
-
-                    del tile_out, final_weight, tile_x, w_t, w_h, w_w
-                    bar.update(1)
 
     result.div_(count.clamp(min=1e-6))
 
     if result.device != x.device:
         result = result.to(x.device).to(x.dtype)
@@ -253,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
             io.Int.Input("spatial_tile_size", default = 512, min = -1),
             io.Int.Input("temporal_tile_size", default = 8, min = -1),
             io.Int.Input("spatial_overlap", default = 64, min = -1),
-            io.Int.Input("temporal_overlap", default = 8, min = -1),
+            io.Boolean.Input("enable_tiling", default=False)
         ],
         outputs = [
             io.Latent.Output("vae_conditioning")
@@ -261,7 +267,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
     )
 
     @classmethod
-    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
+    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
 
         device = vae.patcher.load_device
         offload_device = comfy.model_management.intermediate_device()
@@ -296,9 +302,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae_model.original_image_video = images
 
         args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
-                "temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap}
+                "temporal_size": temporal_tile_size}
+
+        if enable_tiling:
+            latent = tiled_vae(images, vae_model, encode=True, **args)
+        else:
+            latent = vae_model.encode(images, orig_dims=[o_h, o_w])[0]
+
+        # Remember the tiling choice; decode reads it back from tiled_args.
+        args["enable_tiling"] = enable_tiling
         vae_model.tiled_args = args
-        latent = tiled_vae(images, vae_model, encode=True, **args)
 
         vae_model = vae_model.to(offload_device)
         vae_model.img_dims = [o_h, o_w]
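
A minimal usage sketch of the two encode paths above, assuming a loaded
`vae_model` (VideoAutoencoderKLWrapper) and an `images` tensor shaped
(b, c, t, h, w), with `o_h`/`o_w` taken from the node's resize step:

    # Default path: temporal slicing only (spatial tiling disabled).
    latent = vae_model.encode(images, orig_dims=[o_h, o_w])[0]

    # Opt-in path: spatially tiled encode with the node's default arguments.
    args = {"tile_size": (512, 512), "tile_overlap": (64, 64), "temporal_size": 8}
    latent = tiled_vae(images, vae_model, encode=True, **args)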