fixed the speed issue

Yousef Rafat 2025-12-24 02:23:57 +02:00
parent d41b1111eb
commit 1afc2ed8e6
2 changed files with 485 additions and 160 deletions

View File

@@ -5,12 +5,57 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from contextlib import contextmanager
import comfy.model_management
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.modules.attention import optimized_attention
from comfy_extras.nodes_seedvr import tiled_vae
import math
from enum import Enum
from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
_NORM_LIMIT = float("inf")
def get_norm_limit():
return _NORM_LIMIT
def set_norm_limit(value: Optional[float] = None):
global _NORM_LIMIT
if value is None:
value = float("inf")
_NORM_LIMIT = value
@contextmanager
def ignore_padding(model):
orig_padding = model.padding
model.padding = (0, 0, 0)
try:
yield
finally:
model.padding = orig_padding
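# MemoryState drives the temporal cache of InflatedCausalConv3d:
# DISABLED - no cache is kept; INITIALIZING - first slice of a sliced pass (the cache
# is written but not consumed); ACTIVE - later slices splice in the cached tail frames;
# UNSET - sentinel, forward() asserts that callers pass an explicit state.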
class MemoryState(Enum):
DISABLED = 0
INITIALIZING = 1
ACTIVE = 2
UNSET = 3
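# get_cache_size returns how many trailing input elements along `dim` must be carried
# into the next slice so a sliced convolution matches the unsliced result: the
# kernel/stride overlap plus any input that has not produced an output yet.
# Illustrative numbers: kernel_size=3, stride=1, dilation=1, input_len=4, pad_len=0
# -> dilated_kernel_size=3, output_len=2, remain_len=0, overlap_len=2, cache_len=2.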
def get_cache_size(conv_module, input_len, pad_len, dim=0):
dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
remain_len = (
input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size)
)
overlap_len = dilated_kernel_size - conv_module.stride[dim]
cache_len = overlap_len + remain_len # >= 0
assert output_len > 0
return cache_len
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
@@ -34,6 +79,9 @@ class DiagonalGaussianDistribution(object):
x = self.mean + self.std * sample
return x
def mode(self):
return self.mean
class SpatialNorm(nn.Module):
def __init__(
self,
@@ -366,41 +414,233 @@ def extend_head(tensor, times: int = 2, memory = None):
tile_repeat[2] = times
return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
class InflatedCausalConv3d(nn.Conv3d):
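# cache_send_recv is the single-GPU stand-in for the original send/recv exchange: it
# returns the stored memory cast to the first slice's device/dtype when available,
# otherwise the first frame tiled `times` times as causal head padding (None if times == 0).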
def cache_send_recv(tensor, cache_size, times, memory=None):
# Single GPU inference - simplified cache handling
recv_buffer = None
# Handle memory buffer for single GPU case
if memory is not None:
recv_buffer = memory.to(tensor[0])
elif times > 0:
tile_repeat = [1] * tensor[0].ndim
tile_repeat[2] = times
recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat)
return recv_buffer
class InflatedCausalConv3d(torch.nn.Conv3d):
def __init__(
self,
*args,
inflation_mode,
memory_device = "same",
**kwargs,
):
self.inflation_mode = inflation_mode
self.memory = None
super().__init__(*args, **kwargs)
self.temporal_padding = self.padding[0]
self.memory_device = memory_device
self.padding = (0, *self.padding[1:])
self.memory_limit = float("inf")
def set_memory_limit(self, value: float):
self.memory_limit = value
def set_memory_device(self, memory_device):
self.memory_device = memory_device
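# Workaround path: when NVIDIA_MEMORY_CONV_BUG_WORKAROUND is set and the weights are
# fp16/bf16, call the cuDNN convolution directly and add the bias manually; any
# RuntimeError falls back to the stock Conv3d implementation below.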
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and
weight.dtype in (torch.float16, torch.bfloat16) and
hasattr(torch.backends.cudnn, 'is_available') and
torch.backends.cudnn.is_available() and
getattr(torch.backends.cudnn, 'enabled', True)):
try:
out = torch.cudnn_convolution(
input, weight, self.padding, self.stride, self.dilation, self.groups,
benchmark=False, deterministic=False, allow_tf32=True
)
if bias is not None:
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
return out
except RuntimeError:
pass
return super()._conv_forward(input, weight, bias, *args, **kwargs)
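# memory_limit_conv recursively splits the input along the spatial dims whenever the
# concatenated, padded tensor would exceed memory_limit (GiB), carrying the kernel
# overlap between chunks so the chunked result matches a single convolution.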
def memory_limit_conv(
self,
x,
*,
split_dim=3,
padding=(0, 0, 0, 0, 0, 0),
prev_cache=None,
):
# Compatible with no limit.
if math.isinf(self.memory_limit):
if prev_cache is not None:
x = torch.cat([prev_cache, x], dim=split_dim - 1)
return super().forward(x)
# Compute tensor shape after concat & padding.
shape = torch.tensor(x.size())
if prev_cache is not None:
shape[split_dim - 1] += prev_cache.size(split_dim - 1)
shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0)
memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB
if memory_occupy < self.memory_limit or split_dim == x.ndim:
x_concat = x
if prev_cache is not None:
x_concat = torch.cat([prev_cache, x], dim=split_dim - 1)
def pad_and_forward():
padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
with ignore_padding(self):
return torch.nn.Conv3d.forward(self, padded)
return pad_and_forward()
num_splits = math.ceil(memory_occupy / self.memory_limit)
size_per_split = x.size(split_dim) // num_splits
split_sizes = [size_per_split] * (num_splits - 1)
split_sizes += [x.size(split_dim) - sum(split_sizes)]
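# Illustrative split: if memory_occupy is ~2.5x the limit and x.size(split_dim) == 10,
# then num_splits = 3 and split_sizes = [3, 3, 4].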
x = list(x.split(split_sizes, dim=split_dim))
if prev_cache is not None:
prev_cache = list(prev_cache.split(split_sizes, dim=split_dim))
cache = None
for idx in range(len(x)):
if prev_cache is not None:
x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1)
lpad_dim = (x[idx].ndim - split_dim - 1) * 2
rpad_dim = lpad_dim + 1
padding = list(padding)
padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0
padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0
pad_len = padding[lpad_dim] + padding[rpad_dim]
padding = tuple(padding)
next_cache = None
cache_len = cache.size(split_dim) if cache is not None else 0
next_cache_size = get_cache_size(
conv_module=self,
input_len=x[idx].size(split_dim) + cache_len,
pad_len=pad_len,
dim=split_dim - 2,
)
if next_cache_size != 0:
assert next_cache_size <= x[idx].size(split_dim)
next_cache = (
x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim)
)
x[idx] = self.memory_limit_conv(
x[idx],
split_dim=split_dim + 1,
padding=padding,
prev_cache=cache
)
cache = next_cache
output = torch.cat(x, dim=split_dim)
return output
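# forward dispatches: a plain tensor with no memory limit goes through basic_forward
# (one padded convolution); list inputs or a finite memory_limit go through
# slicing_forward. Any stored memory is cleared unless memory_state is ACTIVE.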
def forward(
self,
input,
):
input = extend_head(input, times=self.temporal_padding * 2)
memory_state: MemoryState = MemoryState.UNSET
) -> Tensor:
assert memory_state != MemoryState.UNSET
if memory_state != MemoryState.ACTIVE:
self.memory = None
if (
math.isinf(self.memory_limit)
and torch.is_tensor(input)
):
return self.basic_forward(input, memory_state)
return self.slicing_forward(input, memory_state)
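# basic_forward pads the head (or splices in the cached memory when ACTIVE), then keeps
# the last kernel_size[0] - stride[0] frames of the padded input as memory for the next
# call, optionally offloading that cache to the CPU via memory_device.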
def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET):
mem_size = self.stride[0] - self.kernel_size[0]
if (self.memory is not None) and (memory_state == MemoryState.ACTIVE):
input = extend_head(input, memory=self.memory, times=-1)
else:
input = extend_head(input, times=self.temporal_padding * 2)
memory = (
input[:, :, mem_size:].detach()
if (mem_size != 0 and memory_state != MemoryState.DISABLED)
else None
)
if (
memory_state != MemoryState.DISABLED
and not self.training
and (self.memory_device is not None)
):
self.memory = memory
if self.memory_device == "cpu" and self.memory is not None:
self.memory = self.memory.to("cpu")
return super().forward(input)
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
def slicing_forward(
self,
input,
memory_state: MemoryState = MemoryState.UNSET,
) -> Tensor:
squeeze_out = False
if torch.is_tensor(input):
input = [input]
squeeze_out = True
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
cache_size = self.kernel_size[0] - self.stride[0]
cache = cache_send_recv(
input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2
)
# Single GPU inference - simplified memory management
if (
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing
and not self.training
and (self.memory_device is not None)
and cache_size != 0
):
if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
input[0] = torch.cat([cache, input[0]], dim=2)
cache = None
if cache_size <= input[-1].size(2):
self.memory = input[-1][:, :, -cache_size:].detach().contiguous()
if self.memory_device == "cpu" and self.memory is not None:
self.memory = self.memory.to("cpu")
padding = tuple(x for x in reversed(self.padding) for _ in range(2))
for i in range(len(input)):
# Prepare cache for next input slice.
next_cache = None
cache_size = 0
if i < len(input) - 1:
cache_len = cache.size(2) if cache is not None else 0
cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0)
if cache_size != 0:
if cache_size > input[i].size(2) and cache is not None:
input[i] = torch.cat([cache, input[i]], dim=2)
cache = None
assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}"
next_cache = input[i][:, :, -cache_size:]
# Conv forward for this input slice.
input[i] = self.memory_limit_conv(
input[i],
padding=padding,
prev_cache=cache
)
# Update cache.
cache = next_cache
return input[0] if squeeze_out else input
def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
if times == 0:
return tensor
@@ -488,6 +728,7 @@ class Upsample3D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
memory_state=None,
**kwargs,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
@@ -517,7 +758,7 @@ class Upsample3D(nn.Module):
z=self.temporal_ratio,
)
if self.temporal_up:
if self.temporal_up and memory_state != MemoryState.ACTIVE:
hidden_states[0] = remove_head(hidden_states[0])
if not self.slicing:
@@ -525,9 +766,9 @@ class Upsample3D(nn.Module):
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
hidden_states = self.conv(hidden_states, memory_state=memory_state)
else:
hidden_states = self.Conv2d_0(hidden_states)
hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state)
if not self.slicing:
return hidden_states
@@ -594,6 +835,7 @@ class Downsample3D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
memory_state = None,
**kwargs,
) -> torch.FloatTensor:
@@ -609,7 +851,7 @@ class Downsample3D(nn.Module):
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
hidden_states = self.conv(hidden_states, memory_state=memory_state)
return hidden_states
@@ -707,7 +949,7 @@ class ResnetBlock3D(nn.Module):
)
def forward(
self, input_tensor, temb, **kwargs
self, input_tensor, temb, memory_state = None, **kwargs
):
hidden_states = input_tensor
@@ -719,13 +961,13 @@ class ResnetBlock3D(nn.Module):
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
input_tensor = self.upsample(input_tensor, memory_state=memory_state)
hidden_states = self.upsample(hidden_states, memory_state=memory_state)
elif self.downsample is not None:
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
input_tensor = self.downsample(input_tensor, memory_state=memory_state)
hidden_states = self.downsample(hidden_states, memory_state=memory_state)
hidden_states = self.conv1(hidden_states)
hidden_states = self.conv1(hidden_states, memory_state=memory_state)
if self.time_emb_proj is not None:
if not self.skip_time_act:
@@ -740,10 +982,10 @@ class ResnetBlock3D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.conv2(hidden_states, memory_state=memory_state)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -819,15 +1061,16 @@ class DownEncoderBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
memory_state = None,
**kwargs,
) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
hidden_states = temporal(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, memory_state=memory_state)
return hidden_states
@@ -907,14 +1150,15 @@ class UpDecoderBlock3D(nn.Module):
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
memory_state=None
) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
hidden_states = temporal(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, memory_state=memory_state)
return hidden_states
@@ -1008,9 +1252,9 @@ class UNetMidBlock3D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, memory_state=None):
video_length, frame_height, frame_width = hidden_states.size()[-3:]
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
@@ -1018,7 +1262,7 @@ class UNetMidBlock3D(nn.Module):
hidden_states = rearrange(
hidden_states, "(b f) c h w -> b c f h w", f=video_length
)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, memory_state=memory_state)
return hidden_states
@@ -1136,10 +1380,11 @@ class Encoder3D(nn.Module):
self,
sample: torch.FloatTensor,
extra_cond=None,
memory_state = None
) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample)
sample = self.conv_in(sample, memory_state = memory_state)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
@@ -1164,17 +1409,17 @@ class Encoder3D(nn.Module):
# down
# [Override] add extra block and extra cond
for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
sample = down_block(sample)
sample = down_block(sample, memory_state=memory_state)
if extra_block is not None:
sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:])
# middle
sample = self.mid_block(sample)
sample = self.mid_block(sample, memory_state=memory_state)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
sample = self.conv_out(sample, memory_state = memory_state)
return sample
@@ -1282,74 +1527,90 @@ class Decoder3D(nn.Module):
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
memory_state = None,
) -> torch.FloatTensor:
sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample)
sample = self.conv_in(sample, memory_state=memory_state)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
# middle
sample = self.mid_block(sample, latent_embeds)
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
sample = up_block(sample, latent_embeds, memory_state=memory_state)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
sample = self.conv_out(sample, memory_state=memory_state)
return sample
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
def wavelet_blur(image: Tensor, radius):
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
if radius > max_safe_radius:
radius = max_safe_radius
num_channels = image.shape[1]
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
def wavelet_decomposition(image: Tensor, levels: int = 5):
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
high_freq.add_(image).sub_(low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
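# wavelet_reconstruction keeps the content's high-frequency detail and swaps in the
# style's low-frequency band (overall color/brightness), resizing the style feature
# when shapes differ and clamping the result to [-1, 1].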
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
# safe_interpolate_operation handles FP16 conversion automatically
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
del content_low_freq # Free memory immediately
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = safe_interpolate_operation(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq.add_(style_low_freq)
return content_high_freq.clamp_(-1.0, 1.0)
class VideoAutoencoderKL(nn.Module):
def __init__(
@@ -1368,9 +1629,12 @@ class VideoAutoencoderKL(nn.Module):
time_receptive_field: _receptive_field_t = "full",
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
slicing_sample_min_size = 4,
*args,
**kwargs,
):
self.slicing_sample_min_size = slicing_sample_min_size
self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
block_out_channels = (128, 256, 512, 512)
down_block_types = ("DownEncoderBlock3D",) * 4
@@ -1438,9 +1702,11 @@ class VideoAutoencoderKL(nn.Module):
self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
self.decoder.mid_block.attentions = torch.nn.ModuleList([None])
self.use_slicing = True
def encode(self, x: torch.FloatTensor, return_dict: bool = True):
h = self.slicing_encode(x)
posterior = DiagonalGaussianDistribution(h).sample()
posterior = DiagonalGaussianDistribution(h).mode()
if not return_dict:
return (posterior,)
@@ -1458,30 +1724,72 @@ class VideoAutoencoderKL(nn.Module):
return decoded
def _encode(
self, x: torch.Tensor
self, x, memory_state
) -> torch.Tensor:
_x = x.to(self.device)
h = self.encoder(_x,)
h = self.encoder(_x, memory_state=memory_state)
if self.quant_conv is not None:
output = self.quant_conv(h)
output = self.quant_conv(h, memory_state=memory_state)
else:
output = h
return output.to(x.device)
def _decode(
self, z: torch.Tensor
self, z, memory_state
) -> torch.Tensor:
latent = z.to(self.device)
_z = z.to(self.device)
if self.post_quant_conv is not None:
latent = self.post_quant_conv(latent)
output = self.decoder(latent)
_z = self.post_quant_conv(_z, memory_state=memory_state)
output = self.decoder(_z, memory_state=memory_state)
return output.to(z.device)
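# slicing_encode splits the input along time into slicing_sample_min_size chunks: the
# first chunk (plus frame 0) runs with MemoryState.INITIALIZING, later chunks with
# MemoryState.ACTIVE so each InflatedCausalConv3d reuses its cached tail frames; the
# per-conv memories are cleared once the encoded slices are concatenated.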
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
return self._encode(x)
sp_size = 1
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
encoded_slices = [
self._encode(
torch.cat((x[:, :, :1], x_slices[0]), dim=2),
memory_state=MemoryState.INITIALIZING,
)
]
for x_idx in range(1, len(x_slices)):
encoded_slices.append(
self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
)
out = torch.cat(encoded_slices, dim=2)
modules_with_memory = [m for m in self.modules()
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
for m in modules_with_memory:
m.memory = None
return out
else:
return self._encode(x)
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
return self._decode(z)
sp_size = 1
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
decoded_slices = [
self._decode(
torch.cat((z[:, :, :1], z_slices[0]), dim=2),
memory_state=MemoryState.INITIALIZING
)
]
for z_idx in range(1, len(z_slices)):
decoded_slices.append(
self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
)
out = torch.cat(decoded_slices, dim=2)
modules_with_memory = [m for m in self.modules()
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
for m in modules_with_memory:
m.memory = None
return out
else:
return self._decode(z)
def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
@@ -1531,6 +1839,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
self.freeze_encoder = freeze_encoder
self.original_image_video = None
super().__init__(*args, **kwargs)
self.set_memory_limit(0.5, 0.5)
def forward(self, x: torch.FloatTensor):
with torch.no_grad() if self.freeze_encoder else nullcontext():
@@ -1567,8 +1876,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
target_device = comfy.model_management.get_torch_device()
self.decoder.to(target_device)
x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)
#x = super().decode(latent).squeeze(2)
if self.tiled_args.get("enable_tiling", None) is not None:
self.enable_tiling = self.tiled_args.pop("enable_tiling", False)
if self.enable_tiling:
x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)
else:
x = super().decode_(latent).squeeze(2)
input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w")
if x.ndim == 4:
@@ -1581,6 +1895,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
x = rearrange(x, "b c t h w -> (b t) c h w")
x = wavelet_reconstruction(x, input)
x = x.unsqueeze(0)
o_h, o_w = self.img_dims
x = x[..., :o_h, :o_w]
@@ -1595,8 +1910,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
return x
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
# TODO
#set_norm_limit(norm_max_mem)
set_norm_limit(norm_max_mem)
for m in self.modules():
if isinstance(m, InflatedCausalConv3d):
m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))

View File

@@ -14,25 +14,23 @@ from torchvision.transforms import Lambda, Normalize
from torchvision.transforms.functional import InterpolationMode
@torch.inference_mode()
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):
gc.collect()
torch.cuda.empty_cache()
x = x.to(next(vae_model.parameters()).dtype)
if x.ndim != 5:
x = x.unsqueeze(2)
b, c, d, h, w = x.shape
sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
if encode:
ti_h, ti_w = tile_size
ov_h, ov_w = tile_overlap
ti_t = temporal_size
ov_t = temporal_overlap
target_d = (d + sf_t - 1) // sf_t
target_h = (h + sf_s - 1) // sf_s
target_w = (w + sf_s - 1) // sf_s
@@ -41,21 +39,44 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
ti_w = max(1, tile_size[1] // sf_s)
ov_h = max(0, tile_overlap[0] // sf_s)
ov_w = max(0, tile_overlap[1] // sf_s)
ti_t = max(1, temporal_size // sf_t)
ov_t = max(0, temporal_overlap // sf_t)
target_d = d * sf_t
target_h = h * sf_s
target_w = w * sf_s
stride_t = max(1, ti_t - ov_t)
stride_h = max(1, ti_h - ov_h)
stride_w = max(1, ti_w - ov_w)
storage_device = torch.device("cpu")
result = None
count = None
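# run_temporal_chunks walks one spatial tile along the time axis: temporal_size input
# frames per chunk when encoding, temporal_size // sf_t latent frames when decoding,
# moving each chunk's output to the CPU staging buffer before concatenation.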
def run_temporal_chunks(spatial_tile):
chunk_results = []
t_dim_size = spatial_tile.shape[2]
if encode:
input_chunk = temporal_size
else:
input_chunk = max(1, temporal_size // sf_t)
for i in range(0, t_dim_size, input_chunk):
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
if encode:
out = vae_model.encode(t_chunk)
else:
out = vae_model.decode_(t_chunk)
if isinstance(out, (tuple, list)): out = out[0]
if out.ndim == 4: out = out.unsqueeze(2)
chunk_results.append(out.to(storage_device))
return torch.cat(chunk_results, dim=2)
ramp_cache = {}
def get_ramp(steps):
if steps not in ramp_cache:
@@ -63,79 +84,64 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
return ramp_cache[steps]
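# get_ramp caches a smooth 0 -> 1 half-cosine ramp of `steps` samples; the leading edge
# of an overlap is weighted by the ramp and the trailing edge by its complement so
# adjacent tiles cross-fade.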
bar = ProgressBar(d // stride_t)
for t_idx in range(0, d, stride_t):
t_end = min(t_idx + ti_t, d)
total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
bar = ProgressBar(total_tiles)
for y_idx in range(0, h, stride_h):
y_end = min(y_idx + ti_h, h)
for y_idx in range(0, h, stride_h):
y_end = min(y_idx + ti_h, h)
for x_idx in range(0, w, stride_w):
x_end = min(x_idx + ti_w, w)
for x_idx in range(0, w, stride_w):
x_end = min(x_idx + ti_w, w)
tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end]
# Run VAE
tile_out = run_temporal_chunks(tile_x)
if encode:
tile_out = vae_model.encode(tile_x)[0]
else:
tile_out = vae_model.decode_(tile_x)
if result is None:
b_out, c_out = tile_out.shape[0], tile_out.shape[1]
result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)
if tile_out.ndim == 4:
tile_out = tile_out.unsqueeze(2)
if encode:
ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
else:
ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
tile_out = tile_out.to(storage_device).float()
w_h = torch.ones((tile_out.shape[3],), device=storage_device)
w_w = torch.ones((tile_out.shape[4],), device=storage_device)
if result is None:
b_out, c_out = tile_out.shape[0], tile_out.shape[1]
result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
if cur_ov_h > 0:
r = get_ramp(cur_ov_h)
if y_idx > 0: w_h[:cur_ov_h] = r
if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
if encode:
ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
if cur_ov_w > 0:
r = get_ramp(cur_ov_w)
if x_idx > 0: w_w[:cur_ov_w] = r
if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
else:
ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2]
ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2))
cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
valid_d = min(tile_out.shape[2], result.shape[2])
tile_out = tile_out[:, :, :valid_d, :, :]
tile_out.mul_(final_weight)
result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
count[:, :, :, ys:ye, xs:xe] += final_weight
w_t = torch.ones((tile_out.shape[2],), device=storage_device)
w_h = torch.ones((tile_out.shape[3],), device=storage_device)
w_w = torch.ones((tile_out.shape[4],), device=storage_device)
del tile_out, final_weight, tile_x, w_h, w_w
bar.update(1)
if cur_ov_t > 0:
r = get_ramp(cur_ov_t)
if t_idx > 0: w_t[:cur_ov_t] = r
if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
if cur_ov_h > 0:
r = get_ramp(cur_ov_h)
if y_idx > 0: w_h[:cur_ov_h] = r
if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
if cur_ov_w > 0:
r = get_ramp(cur_ov_w)
if x_idx > 0: w_w[:cur_ov_w] = r
if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
tile_out.mul_(final_weight)
result[:, :, ts:te, ys:ye, xs:xe] += tile_out
count[:, :, ts:te, ys:ye, xs:xe] += final_weight
del tile_out, final_weight, tile_x, w_t, w_h, w_w
bar.update(1)
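# Each voxel accumulated tile_out * weight into `result` and the matching weights into
# `count`; dividing by count (clamped to avoid division by zero) yields the blended average.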
result.div_(count.clamp(min=1e-6))
if result.device != x.device:
result = result.to(x.device).to(x.dtype)
@@ -253,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
io.Int.Input("spatial_tile_size", default = 512, min = -1),
io.Int.Input("temporal_tile_size", default = 8, min = -1),
io.Int.Input("spatial_overlap", default = 64, min = -1),
io.Int.Input("temporal_overlap", default = 8, min = -1),
io.Boolean.Input("enable_tiling", default=False)
],
outputs = [
io.Latent.Output("vae_conditioning")
@@ -261,7 +267,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
)
@classmethod
def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
device = vae.patcher.load_device
offload_device = comfy.model_management.intermediate_device()
@@ -296,9 +302,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
vae_model.original_image_video = images
args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
"temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap}
"temporal_size":temporal_tile_size}
if enable_tiling:
latent = tiled_vae(images, vae_model, encode=True, **args)
else:
latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]
args["enable_tiling"] = enable_tiling
vae_model.tiled_args = args
latent = tiled_vae(images, vae_model, encode=True, **args)
vae_model = vae_model.to(offload_device)
vae_model.img_dims = [o_h, o_w]