mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
fixed the speed issue
This commit is contained in:
parent
d41b1111eb
commit
1afc2ed8e6
@ -5,12 +5,57 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from contextlib import contextmanager
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy_extras.nodes_seedvr import tiled_vae
|
||||
|
||||
import math
|
||||
from enum import Enum
|
||||
from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
|
||||
|
||||
_NORM_LIMIT = float("inf")
|
||||
|
||||
|
||||
def get_norm_limit():
|
||||
return _NORM_LIMIT
|
||||
|
||||
|
||||
def set_norm_limit(value: Optional[float] = None):
|
||||
global _NORM_LIMIT
|
||||
if value is None:
|
||||
value = float("inf")
|
||||
_NORM_LIMIT = value
|
||||
|
||||
@contextmanager
|
||||
def ignore_padding(model):
|
||||
orig_padding = model.padding
|
||||
model.padding = (0, 0, 0)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model.padding = orig_padding
|
||||
|
||||
class MemoryState(Enum):
|
||||
DISABLED = 0
|
||||
INITIALIZING = 1
|
||||
ACTIVE = 2
|
||||
UNSET = 3
|
||||
|
||||
def get_cache_size(conv_module, input_len, pad_len, dim=0):
|
||||
dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
|
||||
output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1
|
||||
remain_len = (
|
||||
input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size)
|
||||
)
|
||||
overlap_len = dilated_kernerl_size - conv_module.stride[dim]
|
||||
cache_len = overlap_len + remain_len # >= 0
|
||||
|
||||
assert output_len > 0
|
||||
return cache_len
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
||||
self.parameters = parameters
|
||||
@ -34,6 +79,9 @@ class DiagonalGaussianDistribution(object):
|
||||
x = self.mean + self.std * sample
|
||||
return x
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
class SpatialNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -366,41 +414,233 @@ def extend_head(tensor, times: int = 2, memory = None):
|
||||
tile_repeat[2] = times
|
||||
return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
|
||||
|
||||
class InflatedCausalConv3d(nn.Conv3d):
|
||||
def cache_send_recv(tensor, cache_size, times, memory=None):
|
||||
# Single GPU inference - simplified cache handling
|
||||
recv_buffer = None
|
||||
|
||||
# Handle memory buffer for single GPU case
|
||||
if memory is not None:
|
||||
recv_buffer = memory.to(tensor[0])
|
||||
elif times > 0:
|
||||
tile_repeat = [1] * tensor[0].ndim
|
||||
tile_repeat[2] = times
|
||||
recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat)
|
||||
|
||||
return recv_buffer
|
||||
|
||||
class InflatedCausalConv3d(torch.nn.Conv3d):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
inflation_mode,
|
||||
memory_device = "same",
|
||||
**kwargs,
|
||||
):
|
||||
self.inflation_mode = inflation_mode
|
||||
self.memory = None
|
||||
super().__init__(*args, **kwargs)
|
||||
self.temporal_padding = self.padding[0]
|
||||
self.memory_device = memory_device
|
||||
self.padding = (0, *self.padding[1:])
|
||||
self.memory_limit = float("inf")
|
||||
|
||||
def set_memory_limit(self, value: float):
|
||||
self.memory_limit = value
|
||||
|
||||
def set_memory_device(self, memory_device):
|
||||
self.memory_device = memory_device
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and
|
||||
weight.dtype in (torch.float16, torch.bfloat16) and
|
||||
hasattr(torch.backends.cudnn, 'is_available') and
|
||||
torch.backends.cudnn.is_available() and
|
||||
getattr(torch.backends.cudnn, 'enabled', True)):
|
||||
try:
|
||||
out = torch.cudnn_convolution(
|
||||
input, weight, self.padding, self.stride, self.dilation, self.groups,
|
||||
benchmark=False, deterministic=False, allow_tf32=True
|
||||
)
|
||||
if bias is not None:
|
||||
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
|
||||
return out
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def memory_limit_conv(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
split_dim=3,
|
||||
padding=(0, 0, 0, 0, 0, 0),
|
||||
prev_cache=None,
|
||||
):
|
||||
# Compatible with no limit.
|
||||
if math.isinf(self.memory_limit):
|
||||
if prev_cache is not None:
|
||||
x = torch.cat([prev_cache, x], dim=split_dim - 1)
|
||||
return super().forward(x)
|
||||
|
||||
# Compute tensor shape after concat & padding.
|
||||
shape = torch.tensor(x.size())
|
||||
if prev_cache is not None:
|
||||
shape[split_dim - 1] += prev_cache.size(split_dim - 1)
|
||||
shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0)
|
||||
memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB
|
||||
if memory_occupy < self.memory_limit or split_dim == x.ndim:
|
||||
x_concat = x
|
||||
if prev_cache is not None:
|
||||
x_concat = torch.cat([prev_cache, x], dim=split_dim - 1)
|
||||
|
||||
def pad_and_forward():
|
||||
padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
|
||||
with ignore_padding(self):
|
||||
return torch.nn.Conv3d.forward(self, padded)
|
||||
|
||||
return pad_and_forward()
|
||||
|
||||
num_splits = math.ceil(memory_occupy / self.memory_limit)
|
||||
size_per_split = x.size(split_dim) // num_splits
|
||||
split_sizes = [size_per_split] * (num_splits - 1)
|
||||
split_sizes += [x.size(split_dim) - sum(split_sizes)]
|
||||
|
||||
x = list(x.split(split_sizes, dim=split_dim))
|
||||
if prev_cache is not None:
|
||||
prev_cache = list(prev_cache.split(split_sizes, dim=split_dim))
|
||||
cache = None
|
||||
for idx in range(len(x)):
|
||||
if prev_cache is not None:
|
||||
x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1)
|
||||
|
||||
lpad_dim = (x[idx].ndim - split_dim - 1) * 2
|
||||
rpad_dim = lpad_dim + 1
|
||||
padding = list(padding)
|
||||
padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0
|
||||
padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0
|
||||
pad_len = padding[lpad_dim] + padding[rpad_dim]
|
||||
padding = tuple(padding)
|
||||
|
||||
next_cache = None
|
||||
cache_len = cache.size(split_dim) if cache is not None else 0
|
||||
next_catch_size = get_cache_size(
|
||||
conv_module=self,
|
||||
input_len=x[idx].size(split_dim) + cache_len,
|
||||
pad_len=pad_len,
|
||||
dim=split_dim - 2,
|
||||
)
|
||||
if next_catch_size != 0:
|
||||
assert next_catch_size <= x[idx].size(split_dim)
|
||||
next_cache = (
|
||||
x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim)
|
||||
)
|
||||
|
||||
x[idx] = self.memory_limit_conv(
|
||||
x[idx],
|
||||
split_dim=split_dim + 1,
|
||||
padding=padding,
|
||||
prev_cache=cache
|
||||
)
|
||||
|
||||
cache = next_cache
|
||||
|
||||
output = torch.cat(x, dim=split_dim)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input,
|
||||
):
|
||||
input = extend_head(input, times=self.temporal_padding * 2)
|
||||
memory_state: MemoryState = MemoryState.UNSET
|
||||
) -> Tensor:
|
||||
assert memory_state != MemoryState.UNSET
|
||||
if memory_state != MemoryState.ACTIVE:
|
||||
self.memory = None
|
||||
if (
|
||||
math.isinf(self.memory_limit)
|
||||
and torch.is_tensor(input)
|
||||
):
|
||||
return self.basic_forward(input, memory_state)
|
||||
return self.slicing_forward(input, memory_state)
|
||||
|
||||
def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET):
|
||||
mem_size = self.stride[0] - self.kernel_size[0]
|
||||
if (self.memory is not None) and (memory_state == MemoryState.ACTIVE):
|
||||
input = extend_head(input, memory=self.memory, times=-1)
|
||||
else:
|
||||
input = extend_head(input, times=self.temporal_padding * 2)
|
||||
memory = (
|
||||
input[:, :, mem_size:].detach()
|
||||
if (mem_size != 0 and memory_state != MemoryState.DISABLED)
|
||||
else None
|
||||
)
|
||||
if (
|
||||
memory_state != MemoryState.DISABLED
|
||||
and not self.training
|
||||
and (self.memory_device is not None)
|
||||
):
|
||||
self.memory = memory
|
||||
if self.memory_device == "cpu" and self.memory is not None:
|
||||
self.memory = self.memory.to("cpu")
|
||||
return super().forward(input)
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
def slicing_forward(
|
||||
self,
|
||||
input,
|
||||
memory_state: MemoryState = MemoryState.UNSET,
|
||||
) -> Tensor:
|
||||
squeeze_out = False
|
||||
if torch.is_tensor(input):
|
||||
input = [input]
|
||||
squeeze_out = True
|
||||
|
||||
super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
cache_size = self.kernel_size[0] - self.stride[0]
|
||||
cache = cache_send_recv(
|
||||
input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2
|
||||
)
|
||||
|
||||
# Single GPU inference - simplified memory management
|
||||
if (
|
||||
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing
|
||||
and not self.training
|
||||
and (self.memory_device is not None)
|
||||
and cache_size != 0
|
||||
):
|
||||
if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
|
||||
input[0] = torch.cat([cache, input[0]], dim=2)
|
||||
cache = None
|
||||
if cache_size <= input[-1].size(2):
|
||||
self.memory = input[-1][:, :, -cache_size:].detach().contiguous()
|
||||
if self.memory_device == "cpu" and self.memory is not None:
|
||||
self.memory = self.memory.to("cpu")
|
||||
|
||||
padding = tuple(x for x in reversed(self.padding) for _ in range(2))
|
||||
for i in range(len(input)):
|
||||
# Prepare cache for next input slice.
|
||||
next_cache = None
|
||||
cache_size = 0
|
||||
if i < len(input) - 1:
|
||||
cache_len = cache.size(2) if cache is not None else 0
|
||||
cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0)
|
||||
if cache_size != 0:
|
||||
if cache_size > input[i].size(2) and cache is not None:
|
||||
input[i] = torch.cat([cache, input[i]], dim=2)
|
||||
cache = None
|
||||
assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}"
|
||||
next_cache = input[i][:, :, -cache_size:]
|
||||
|
||||
# Conv forward for this input slice.
|
||||
input[i] = self.memory_limit_conv(
|
||||
input[i],
|
||||
padding=padding,
|
||||
prev_cache=cache
|
||||
)
|
||||
|
||||
# Update cache.
|
||||
cache = next_cache
|
||||
|
||||
return input[0] if squeeze_out else input
|
||||
|
||||
def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
|
||||
if times == 0:
|
||||
return tensor
|
||||
@ -488,6 +728,7 @@ class Upsample3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state=None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
@ -517,7 +758,7 @@ class Upsample3D(nn.Module):
|
||||
z=self.temporal_ratio,
|
||||
)
|
||||
|
||||
if self.temporal_up:
|
||||
if self.temporal_up and memory_state != MemoryState.ACTIVE:
|
||||
hidden_states[0] = remove_head(hidden_states[0])
|
||||
|
||||
if not self.slicing:
|
||||
@ -525,9 +766,9 @@ class Upsample3D(nn.Module):
|
||||
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.conv(hidden_states, memory_state=memory_state)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state)
|
||||
|
||||
if not self.slicing:
|
||||
return hidden_states
|
||||
@ -594,6 +835,7 @@ class Downsample3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
@ -609,7 +851,7 @@ class Downsample3D(nn.Module):
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.conv(hidden_states, memory_state=memory_state)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -707,7 +949,7 @@ class ResnetBlock3D(nn.Module):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_tensor, temb, **kwargs
|
||||
self, input_tensor, temb, memory_state = None, **kwargs
|
||||
):
|
||||
hidden_states = input_tensor
|
||||
|
||||
@ -719,13 +961,13 @@ class ResnetBlock3D(nn.Module):
|
||||
if hidden_states.shape[0] >= 64:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
input_tensor = self.upsample(input_tensor)
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
input_tensor = self.upsample(input_tensor, memory_state=memory_state)
|
||||
hidden_states = self.upsample(hidden_states, memory_state=memory_state)
|
||||
elif self.downsample is not None:
|
||||
input_tensor = self.downsample(input_tensor)
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
input_tensor = self.downsample(input_tensor, memory_state=memory_state)
|
||||
hidden_states = self.downsample(hidden_states, memory_state=memory_state)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states, memory_state=memory_state)
|
||||
|
||||
if self.time_emb_proj is not None:
|
||||
if not self.skip_time_act:
|
||||
@ -740,10 +982,10 @@ class ResnetBlock3D(nn.Module):
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, memory_state=memory_state)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
@ -819,15 +1061,16 @@ class DownEncoderBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
for resnet, temporal in zip(self.resnets, self.temporal_modules):
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
|
||||
hidden_states = temporal(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, memory_state=memory_state)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -907,14 +1150,15 @@ class UpDecoderBlock3D(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
memory_state=None
|
||||
) -> torch.FloatTensor:
|
||||
for resnet, temporal in zip(self.resnets, self.temporal_modules):
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
|
||||
hidden_states = temporal(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
hidden_states = upsampler(hidden_states, memory_state=memory_state)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -1008,9 +1252,9 @@ class UNetMidBlock3D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
def forward(self, hidden_states, temb=None, memory_state=None):
|
||||
video_length, frame_height, frame_width = hidden_states.size()[-3:]
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
||||
@ -1018,7 +1262,7 @@ class UNetMidBlock3D(nn.Module):
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b f) c h w -> b c f h w", f=video_length
|
||||
)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, memory_state=memory_state)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -1136,10 +1380,11 @@ class Encoder3D(nn.Module):
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
extra_cond=None,
|
||||
memory_state = None
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
sample = sample.to(next(self.parameters()).device)
|
||||
sample = self.conv_in(sample)
|
||||
sample = self.conv_in(sample, memory_state = memory_state)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@ -1164,17 +1409,17 @@ class Encoder3D(nn.Module):
|
||||
# down
|
||||
# [Override] add extra block and extra cond
|
||||
for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
|
||||
sample = down_block(sample)
|
||||
sample = down_block(sample, memory_state=memory_state)
|
||||
if extra_block is not None:
|
||||
sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:])
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample)
|
||||
sample = self.mid_block(sample, memory_state=memory_state)
|
||||
|
||||
# post-process
|
||||
sample = causal_norm_wrapper(self.conv_norm_out, sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
sample = self.conv_out(sample, memory_state = memory_state)
|
||||
|
||||
return sample
|
||||
|
||||
@ -1282,74 +1527,90 @@ class Decoder3D(nn.Module):
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
latent_embeds: Optional[torch.FloatTensor] = None,
|
||||
memory_state = None,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
sample = sample.to(next(self.parameters()).device)
|
||||
sample = self.conv_in(sample)
|
||||
sample = self.conv_in(sample, memory_state=memory_state)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = up_block(sample, latent_embeds)
|
||||
sample = up_block(sample, latent_embeds, memory_state=memory_state)
|
||||
|
||||
# post-process
|
||||
sample = causal_norm_wrapper(self.conv_norm_out, sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
sample = self.conv_out(sample, memory_state=memory_state)
|
||||
|
||||
return sample
|
||||
|
||||
def wavelet_blur(image: Tensor, radius: int):
|
||||
"""
|
||||
Apply wavelet blur to the input tensor.
|
||||
"""
|
||||
# input shape: (1, 3, H, W)
|
||||
# convolution kernel
|
||||
def wavelet_blur(image: Tensor, radius):
|
||||
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
|
||||
if radius > max_safe_radius:
|
||||
radius = max_safe_radius
|
||||
|
||||
num_channels = image.shape[1]
|
||||
|
||||
kernel_vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
# add channel dimensions to the kernel to make it a 4D tensor
|
||||
kernel = kernel[None, None]
|
||||
# repeat the kernel across all input channels
|
||||
kernel = kernel.repeat(3, 1, 1, 1)
|
||||
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
||||
# apply convolution
|
||||
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
||||
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
|
||||
|
||||
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
|
||||
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
|
||||
|
||||
return output
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels=5):
|
||||
"""
|
||||
Apply wavelet decomposition to the input tensor.
|
||||
This function only returns the low frequency & the high frequency.
|
||||
"""
|
||||
def wavelet_decomposition(image: Tensor, levels: int = 5):
|
||||
high_freq = torch.zeros_like(image)
|
||||
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
low_freq = wavelet_blur(image, radius)
|
||||
high_freq += (image - low_freq)
|
||||
high_freq.add_(image).sub_(low_freq)
|
||||
image = low_freq
|
||||
|
||||
|
||||
return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
||||
"""
|
||||
Apply wavelet decomposition, so that the content will have the same color as the style.
|
||||
"""
|
||||
# calculate the wavelet decomposition of the content feature
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
# Resize style to match content spatial dimensions
|
||||
if len(content_feat.shape) >= 3:
|
||||
# safe_interpolate_operation handles FP16 conversion automatically
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Decompose both features into frequency components
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq
|
||||
# calculate the wavelet decomposition of the style feature
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq
|
||||
# reconstruct the content feature with the style's high frequency
|
||||
return content_high_freq + style_low_freq
|
||||
del content_low_freq # Free memory immediately
|
||||
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq # Free memory immediately
|
||||
|
||||
if content_high_freq.shape != style_low_freq.shape:
|
||||
style_low_freq = safe_interpolate_operation(
|
||||
style_low_freq,
|
||||
size=content_high_freq.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
content_high_freq.add_(style_low_freq)
|
||||
|
||||
return content_high_freq.clamp_(-1.0, 1.0)
|
||||
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(
|
||||
@ -1368,9 +1629,12 @@ class VideoAutoencoderKL(nn.Module):
|
||||
time_receptive_field: _receptive_field_t = "full",
|
||||
use_quant_conv: bool = False,
|
||||
use_post_quant_conv: bool = False,
|
||||
slicing_sample_min_size = 4,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.slicing_sample_min_size = slicing_sample_min_size
|
||||
self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
|
||||
extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
|
||||
block_out_channels = (128, 256, 512, 512)
|
||||
down_block_types = ("DownEncoderBlock3D",) * 4
|
||||
@ -1438,9 +1702,11 @@ class VideoAutoencoderKL(nn.Module):
|
||||
self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
|
||||
self.decoder.mid_block.attentions = torch.nn.ModuleList([None])
|
||||
|
||||
self.use_slicing = True
|
||||
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True):
|
||||
h = self.slicing_encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h).sample()
|
||||
posterior = DiagonalGaussianDistribution(h).mode()
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
@ -1458,30 +1724,72 @@ class VideoAutoencoderKL(nn.Module):
|
||||
return decoded
|
||||
|
||||
def _encode(
|
||||
self, x: torch.Tensor
|
||||
self, x, memory_state
|
||||
) -> torch.Tensor:
|
||||
_x = x.to(self.device)
|
||||
h = self.encoder(_x,)
|
||||
h = self.encoder(_x, memory_state=memory_state)
|
||||
if self.quant_conv is not None:
|
||||
output = self.quant_conv(h)
|
||||
output = self.quant_conv(h, memory_state=memory_state)
|
||||
else:
|
||||
output = h
|
||||
return output.to(x.device)
|
||||
|
||||
def _decode(
|
||||
self, z: torch.Tensor
|
||||
self, z, memory_state
|
||||
) -> torch.Tensor:
|
||||
latent = z.to(self.device)
|
||||
_z = z.to(self.device)
|
||||
|
||||
if self.post_quant_conv is not None:
|
||||
latent = self.post_quant_conv(latent)
|
||||
output = self.decoder(latent)
|
||||
_z = self.post_quant_conv(_z, memory_state=memory_state)
|
||||
|
||||
output = self.decoder(_z, memory_state=memory_state)
|
||||
return output.to(z.device)
|
||||
|
||||
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self._encode(x)
|
||||
sp_size =1
|
||||
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
|
||||
x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
|
||||
encoded_slices = [
|
||||
self._encode(
|
||||
torch.cat((x[:, :, :1], x_slices[0]), dim=2),
|
||||
memory_state=MemoryState.INITIALIZING,
|
||||
)
|
||||
]
|
||||
for x_idx in range(1, len(x_slices)):
|
||||
encoded_slices.append(
|
||||
self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
|
||||
)
|
||||
out = torch.cat(encoded_slices, dim=2)
|
||||
modules_with_memory = [m for m in self.modules()
|
||||
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
|
||||
for m in modules_with_memory:
|
||||
m.memory = None
|
||||
return out
|
||||
else:
|
||||
return self._encode(x)
|
||||
|
||||
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self._decode(z)
|
||||
sp_size = 1
|
||||
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
|
||||
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
|
||||
decoded_slices = [
|
||||
self._decode(
|
||||
torch.cat((z[:, :, :1], z_slices[0]), dim=2),
|
||||
memory_state=MemoryState.INITIALIZING
|
||||
)
|
||||
]
|
||||
for z_idx in range(1, len(z_slices)):
|
||||
decoded_slices.append(
|
||||
self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
|
||||
)
|
||||
out = torch.cat(decoded_slices, dim=2)
|
||||
modules_with_memory = [m for m in self.modules()
|
||||
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
|
||||
for m in modules_with_memory:
|
||||
m.memory = None
|
||||
return out
|
||||
else:
|
||||
return self._decode(z)
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
@ -1531,6 +1839,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
self.freeze_encoder = freeze_encoder
|
||||
self.original_image_video = None
|
||||
super().__init__(*args, **kwargs)
|
||||
self.set_memory_limit(0.5, 0.5)
|
||||
|
||||
def forward(self, x: torch.FloatTensor):
|
||||
with torch.no_grad() if self.freeze_encoder else nullcontext():
|
||||
@ -1567,8 +1876,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
|
||||
target_device = comfy.model_management.get_torch_device()
|
||||
self.decoder.to(target_device)
|
||||
x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)
|
||||
#x = super().decode(latent).squeeze(2)
|
||||
if self.tiled_args.get("enable_tiling", None) is not None:
|
||||
self.enable_tiling = self.tiled_args.pop("enable_tiling", False)
|
||||
|
||||
if self.enable_tiling:
|
||||
x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)
|
||||
else:
|
||||
x = super().decode_(latent).squeeze(2)
|
||||
|
||||
input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w")
|
||||
if x.ndim == 4:
|
||||
@ -1581,6 +1895,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
|
||||
x = wavelet_reconstruction(x, input)
|
||||
|
||||
x = x.unsqueeze(0)
|
||||
o_h, o_w = self.img_dims
|
||||
x = x[..., :o_h, :o_w]
|
||||
@ -1595,8 +1910,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
return x
|
||||
|
||||
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
|
||||
# TODO
|
||||
#set_norm_limit(norm_max_mem)
|
||||
set_norm_limit(norm_max_mem)
|
||||
for m in self.modules():
|
||||
if isinstance(m, InflatedCausalConv3d):
|
||||
m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
|
||||
|
||||
@ -14,25 +14,23 @@ from torchvision.transforms import Lambda, Normalize
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
|
||||
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
x = x.to(next(vae_model.parameters()).dtype)
|
||||
if x.ndim != 5:
|
||||
x = x.unsqueeze(2)
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
|
||||
sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
|
||||
sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
|
||||
|
||||
if encode:
|
||||
ti_h, ti_w = tile_size
|
||||
ov_h, ov_w = tile_overlap
|
||||
ti_t = temporal_size
|
||||
ov_t = temporal_overlap
|
||||
|
||||
target_d = (d + sf_t - 1) // sf_t
|
||||
target_h = (h + sf_s - 1) // sf_s
|
||||
target_w = (w + sf_s - 1) // sf_s
|
||||
@ -41,21 +39,44 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
ti_w = max(1, tile_size[1] // sf_s)
|
||||
ov_h = max(0, tile_overlap[0] // sf_s)
|
||||
ov_w = max(0, tile_overlap[1] // sf_s)
|
||||
ti_t = max(1, temporal_size // sf_t)
|
||||
ov_t = max(0, temporal_overlap // sf_t)
|
||||
|
||||
|
||||
target_d = d * sf_t
|
||||
target_h = h * sf_s
|
||||
target_w = w * sf_s
|
||||
|
||||
stride_t = max(1, ti_t - ov_t)
|
||||
stride_h = max(1, ti_h - ov_h)
|
||||
stride_w = max(1, ti_w - ov_w)
|
||||
|
||||
storage_device = torch.device("cpu")
|
||||
|
||||
result = None
|
||||
count = None
|
||||
|
||||
def run_temporal_chunks(spatial_tile):
|
||||
chunk_results = []
|
||||
t_dim_size = spatial_tile.shape[2]
|
||||
|
||||
if encode:
|
||||
input_chunk = temporal_size
|
||||
else:
|
||||
input_chunk = max(1, temporal_size // sf_t)
|
||||
|
||||
for i in range(0, t_dim_size, input_chunk):
|
||||
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
|
||||
|
||||
if encode:
|
||||
out = vae_model.encode(t_chunk)
|
||||
else:
|
||||
out = vae_model.decode_(t_chunk)
|
||||
|
||||
if isinstance(out, (tuple, list)): out = out[0]
|
||||
|
||||
if out.ndim == 4: out = out.unsqueeze(2)
|
||||
|
||||
chunk_results.append(out.to(storage_device))
|
||||
|
||||
return torch.cat(chunk_results, dim=2)
|
||||
|
||||
ramp_cache = {}
|
||||
def get_ramp(steps):
|
||||
if steps not in ramp_cache:
|
||||
@ -63,79 +84,64 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
|
||||
return ramp_cache[steps]
|
||||
|
||||
bar = ProgressBar(d // stride_t)
|
||||
for t_idx in range(0, d, stride_t):
|
||||
t_end = min(t_idx + ti_t, d)
|
||||
total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
|
||||
bar = ProgressBar(total_tiles)
|
||||
|
||||
for y_idx in range(0, h, stride_h):
|
||||
y_end = min(y_idx + ti_h, h)
|
||||
for y_idx in range(0, h, stride_h):
|
||||
y_end = min(y_idx + ti_h, h)
|
||||
|
||||
for x_idx in range(0, w, stride_w):
|
||||
x_end = min(x_idx + ti_w, w)
|
||||
|
||||
for x_idx in range(0, w, stride_w):
|
||||
x_end = min(x_idx + ti_w, w)
|
||||
tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
|
||||
|
||||
tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end]
|
||||
# Run VAE
|
||||
tile_out = run_temporal_chunks(tile_x)
|
||||
|
||||
if encode:
|
||||
tile_out = vae_model.encode(tile_x)[0]
|
||||
else:
|
||||
tile_out = vae_model.decode_(tile_x)
|
||||
if result is None:
|
||||
b_out, c_out = tile_out.shape[0], tile_out.shape[1]
|
||||
result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
|
||||
count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)
|
||||
|
||||
if tile_out.ndim == 4:
|
||||
tile_out = tile_out.unsqueeze(2)
|
||||
if encode:
|
||||
ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
|
||||
xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
|
||||
cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
|
||||
cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
|
||||
else:
|
||||
ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
|
||||
xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
|
||||
cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
|
||||
cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
|
||||
|
||||
tile_out = tile_out.to(storage_device).float()
|
||||
w_h = torch.ones((tile_out.shape[3],), device=storage_device)
|
||||
w_w = torch.ones((tile_out.shape[4],), device=storage_device)
|
||||
|
||||
if result is None:
|
||||
b_out, c_out = tile_out.shape[0], tile_out.shape[1]
|
||||
result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
|
||||
count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
|
||||
if cur_ov_h > 0:
|
||||
r = get_ramp(cur_ov_h)
|
||||
if y_idx > 0: w_h[:cur_ov_h] = r
|
||||
if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
|
||||
|
||||
if encode:
|
||||
ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
|
||||
ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
|
||||
xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
|
||||
if cur_ov_w > 0:
|
||||
r = get_ramp(cur_ov_w)
|
||||
if x_idx > 0: w_w[:cur_ov_w] = r
|
||||
if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
|
||||
|
||||
cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
|
||||
cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
|
||||
cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
|
||||
else:
|
||||
ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2]
|
||||
ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
|
||||
xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
|
||||
final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
|
||||
|
||||
cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2))
|
||||
cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
|
||||
cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
|
||||
valid_d = min(tile_out.shape[2], result.shape[2])
|
||||
tile_out = tile_out[:, :, :valid_d, :, :]
|
||||
|
||||
tile_out.mul_(final_weight)
|
||||
|
||||
result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
|
||||
count[:, :, :, ys:ye, xs:xe] += final_weight
|
||||
|
||||
w_t = torch.ones((tile_out.shape[2],), device=storage_device)
|
||||
w_h = torch.ones((tile_out.shape[3],), device=storage_device)
|
||||
w_w = torch.ones((tile_out.shape[4],), device=storage_device)
|
||||
del tile_out, final_weight, tile_x, w_h, w_w
|
||||
bar.update(1)
|
||||
|
||||
if cur_ov_t > 0:
|
||||
r = get_ramp(cur_ov_t)
|
||||
if t_idx > 0: w_t[:cur_ov_t] = r
|
||||
if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
|
||||
|
||||
if cur_ov_h > 0:
|
||||
r = get_ramp(cur_ov_h)
|
||||
if y_idx > 0: w_h[:cur_ov_h] = r
|
||||
if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
|
||||
|
||||
if cur_ov_w > 0:
|
||||
r = get_ramp(cur_ov_w)
|
||||
if x_idx > 0: w_w[:cur_ov_w] = r
|
||||
if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
|
||||
|
||||
final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
|
||||
|
||||
tile_out.mul_(final_weight)
|
||||
result[:, :, ts:te, ys:ye, xs:xe] += tile_out
|
||||
count[:, :, ts:te, ys:ye, xs:xe] += final_weight
|
||||
|
||||
del tile_out, final_weight, tile_x, w_t, w_h, w_w
|
||||
bar.update(1)
|
||||
result.div_(count.clamp(min=1e-6))
|
||||
|
||||
|
||||
if result.device != x.device:
|
||||
result = result.to(x.device).to(x.dtype)
|
||||
|
||||
@ -253,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
io.Int.Input("spatial_tile_size", default = 512, min = -1),
|
||||
io.Int.Input("temporal_tile_size", default = 8, min = -1),
|
||||
io.Int.Input("spatial_overlap", default = 64, min = -1),
|
||||
io.Int.Input("temporal_overlap", default = 8, min = -1),
|
||||
io.Boolean.Input("enable_tiling", default=False)
|
||||
],
|
||||
outputs = [
|
||||
io.Latent.Output("vae_conditioning")
|
||||
@ -261,7 +267,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
|
||||
def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
|
||||
device = vae.patcher.load_device
|
||||
|
||||
offload_device = comfy.model_management.intermediate_device()
|
||||
@ -296,9 +302,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
vae_model.original_image_video = images
|
||||
|
||||
args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
|
||||
"temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap}
|
||||
"temporal_size":temporal_tile_size}
|
||||
if enable_tiling:
|
||||
latent = tiled_vae(images, vae_model, encode=True, **args)
|
||||
else:
|
||||
latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]
|
||||
|
||||
args["enable_tiling"] = enable_tiling
|
||||
vae_model.tiled_args = args
|
||||
latent = tiled_vae(images, vae_model, encode=True, **args)
|
||||
|
||||
vae_model = vae_model.to(offload_device)
|
||||
vae_model.img_dims = [o_h, o_w]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user