mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 1afc2ed8e6 (parent d41b1111eb)

    fixed the speed issue
@@ -5,12 +5,57 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
+from contextlib import contextmanager
 import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
 from comfy_extras.nodes_seedvr import tiled_vae
+import math
+from enum import Enum
+from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
+
+_NORM_LIMIT = float("inf")
+
+
+def get_norm_limit():
+    return _NORM_LIMIT
+
+
+def set_norm_limit(value: Optional[float] = None):
+    global _NORM_LIMIT
+    if value is None:
+        value = float("inf")
+    _NORM_LIMIT = value
+
+
+@contextmanager
+def ignore_padding(model):
+    orig_padding = model.padding
+    model.padding = (0, 0, 0)
+    try:
+        yield
+    finally:
+        model.padding = orig_padding
+
+
+class MemoryState(Enum):
+    DISABLED = 0
+    INITIALIZING = 1
+    ACTIVE = 2
+    UNSET = 3
+
+
+def get_cache_size(conv_module, input_len, pad_len, dim=0):
+    dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
+    output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
+    remain_len = (
+        input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size)
+    )
+    overlap_len = dilated_kernel_size - conv_module.stride[dim]
+    cache_len = overlap_len + remain_len  # >= 0
+
+    assert output_len > 0
+    return cache_len
+

 class DiagonalGaussianDistribution(object):
     def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
         self.parameters = parameters
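
A quick numeric check of the cache-size arithmetic in get_cache_size (standalone sketch, not part of the commit): for kernel_size=3, stride=1, dilation=1 and an 8-frame slice with no padding, the conv consumes every frame, but the last two still overlap the next window's receptive field and must be carried forward.

    # kernel_size=3, stride=1, dilation=1, input_len=8, pad_len=0
    dilated_kernel_size = 1 * (3 - 1) + 1      # 3
    output_len = (8 + 0 - 3) // 1 + 1          # 6 output frames
    remain_len = 8 + 0 - ((6 - 1) * 1 + 3)     # 0 unconsumed frames
    overlap_len = 3 - 1                        # 2 frames shared with the next window
    cache_len = overlap_len + remain_len       # 2 -> carry the last 2 input frames
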
@@ -34,6 +79,9 @@ class DiagonalGaussianDistribution(object):
         x = self.mean + self.std * sample
         return x

+    def mode(self):
+        return self.mean
+
 class SpatialNorm(nn.Module):
     def __init__(
         self,
@@ -366,41 +414,233 @@ def extend_head(tensor, times: int = 2, memory = None):
     tile_repeat[2] = times
     return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)

-class InflatedCausalConv3d(nn.Conv3d):
+def cache_send_recv(tensor, cache_size, times, memory=None):
+    # Single GPU inference - simplified cache handling
+    recv_buffer = None
+
+    # Handle memory buffer for single GPU case
+    if memory is not None:
+        recv_buffer = memory.to(tensor[0])
+    elif times > 0:
+        tile_repeat = [1] * tensor[0].ndim
+        tile_repeat[2] = times
+        recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat)
+
+    return recv_buffer
+
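
For intuition: the first temporal slice has no history, so cache_send_recv builds its receive buffer by tiling the first frame, exactly like extend_head above; later slices instead reuse the memory saved from the previous slice. The first-slice case in isolation (standalone sketch, not part of the commit):

    import torch
    x = torch.arange(4.).view(1, 1, 4, 1, 1)     # (B, C, T, H, W), frames 0..3
    tile_repeat = [1] * x.ndim
    tile_repeat[2] = 2                           # times = 2
    head = torch.tile(x[:, :, :1], tile_repeat)  # frame 0, repeated twice
    assert head.shape[2] == 2 and (head == 0).all()
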
+class InflatedCausalConv3d(torch.nn.Conv3d):
     def __init__(
         self,
         *args,
         inflation_mode,
+        memory_device = "same",
         **kwargs,
     ):
         self.inflation_mode = inflation_mode
         self.memory = None
         super().__init__(*args, **kwargs)
         self.temporal_padding = self.padding[0]
+        self.memory_device = memory_device
         self.padding = (0, *self.padding[1:])
         self.memory_limit = float("inf")
+
+    def set_memory_limit(self, value: float):
+        self.memory_limit = value
+
+    def set_memory_device(self, memory_device):
+        self.memory_device = memory_device
+
+    def _conv_forward(self, input, weight, bias, *args, **kwargs):
+        if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and
+                weight.dtype in (torch.float16, torch.bfloat16) and
+                hasattr(torch.backends.cudnn, 'is_available') and
+                torch.backends.cudnn.is_available() and
+                getattr(torch.backends.cudnn, 'enabled', True)):
+            try:
+                out = torch.cudnn_convolution(
+                    input, weight, self.padding, self.stride, self.dilation, self.groups,
+                    benchmark=False, deterministic=False, allow_tf32=True
+                )
+                if bias is not None:
+                    out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
+                return out
+            except RuntimeError:
+                pass
+
+        return super()._conv_forward(input, weight, bias, *args, **kwargs)
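
A note on the cuDNN path: the low-level torch.cudnn_convolution op takes no bias argument, which is presumably why the bias is broadcast-added by hand before returning; any RuntimeError falls back to the stock _conv_forward. The broadcast reshape in isolation (standalone sketch, not part of the commit):

    import torch
    bias = torch.arange(4.)                                # one value per output channel
    out = torch.zeros(1, 4, 2, 2, 2)                       # (B, C_out, D, H, W)
    out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))   # bias viewed as (1, 4, 1, 1, 1)
    assert torch.equal(out[0, 3], torch.full((2, 2, 2), 3.))
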
+
+    def memory_limit_conv(
+        self,
+        x,
+        *,
+        split_dim=3,
+        padding=(0, 0, 0, 0, 0, 0),
+        prev_cache=None,
+    ):
+        # Compatible with no limit.
+        if math.isinf(self.memory_limit):
+            if prev_cache is not None:
+                x = torch.cat([prev_cache, x], dim=split_dim - 1)
+            return super().forward(x)
+
+        # Compute tensor shape after concat & padding.
+        shape = torch.tensor(x.size())
+        if prev_cache is not None:
+            shape[split_dim - 1] += prev_cache.size(split_dim - 1)
+        shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0)
+        memory_occupy = shape.prod() * x.element_size() / 1024**3  # GiB
+        if memory_occupy < self.memory_limit or split_dim == x.ndim:
+            x_concat = x
+            if prev_cache is not None:
+                x_concat = torch.cat([prev_cache, x], dim=split_dim - 1)
+
+            def pad_and_forward():
+                padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
+                with ignore_padding(self):
+                    return torch.nn.Conv3d.forward(self, padded)
+
+            return pad_and_forward()
+
+        num_splits = math.ceil(memory_occupy / self.memory_limit)
+        size_per_split = x.size(split_dim) // num_splits
+        split_sizes = [size_per_split] * (num_splits - 1)
+        split_sizes += [x.size(split_dim) - sum(split_sizes)]
+
+        x = list(x.split(split_sizes, dim=split_dim))
+        if prev_cache is not None:
+            prev_cache = list(prev_cache.split(split_sizes, dim=split_dim))
+        cache = None
+        for idx in range(len(x)):
+            if prev_cache is not None:
+                x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1)
+
+            lpad_dim = (x[idx].ndim - split_dim - 1) * 2
+            rpad_dim = lpad_dim + 1
+            padding = list(padding)
+            padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0
+            padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0
+            pad_len = padding[lpad_dim] + padding[rpad_dim]
+            padding = tuple(padding)
+
+            next_cache = None
+            cache_len = cache.size(split_dim) if cache is not None else 0
+            next_cache_size = get_cache_size(
+                conv_module=self,
+                input_len=x[idx].size(split_dim) + cache_len,
+                pad_len=pad_len,
+                dim=split_dim - 2,
+            )
+            if next_cache_size != 0:
+                assert next_cache_size <= x[idx].size(split_dim)
+                next_cache = (
+                    x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim)
+                )
+
+            x[idx] = self.memory_limit_conv(
+                x[idx],
+                split_dim=split_dim + 1,
+                padding=padding,
+                prev_cache=cache
+            )
+
+            cache = next_cache
+
+        output = torch.cat(x, dim=split_dim)
+        return output
+
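
memory_limit_conv caps peak activation memory: it estimates the post-concat, post-padding tensor size in GiB and, when that exceeds memory_limit, splits along split_dim (height first, then width on the recursive call) while threading an overlap cache between chunks so the seams match a monolithic conv. The split bookkeeping in isolation (standalone sketch, not part of the commit):

    import math
    memory_occupy, memory_limit, dim_size = 3.2, 0.5, 100
    num_splits = math.ceil(memory_occupy / memory_limit)   # 7
    size_per_split = dim_size // num_splits                # 14
    split_sizes = [size_per_split] * (num_splits - 1)      # six chunks of 14 rows
    split_sizes += [dim_size - sum(split_sizes)]           # remainder chunk of 16 rows
    assert sum(split_sizes) == dim_size
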
     def forward(
         self,
         input,
+        memory_state: MemoryState = MemoryState.UNSET
+    ) -> Tensor:
+        assert memory_state != MemoryState.UNSET
+        if memory_state != MemoryState.ACTIVE:
+            self.memory = None
+        if (
+            math.isinf(self.memory_limit)
+            and torch.is_tensor(input)
         ):
+            return self.basic_forward(input, memory_state)
+        return self.slicing_forward(input, memory_state)
+
+    def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET):
+        mem_size = self.stride[0] - self.kernel_size[0]
+        if (self.memory is not None) and (memory_state == MemoryState.ACTIVE):
+            input = extend_head(input, memory=self.memory, times=-1)
+        else:
             input = extend_head(input, times=self.temporal_padding * 2)
+        memory = (
+            input[:, :, mem_size:].detach()
+            if (mem_size != 0 and memory_state != MemoryState.DISABLED)
+            else None
+        )
+        if (
+            memory_state != MemoryState.DISABLED
+            and not self.training
+            and (self.memory_device is not None)
+        ):
+            self.memory = memory
+            if self.memory_device == "cpu" and self.memory is not None:
+                self.memory = self.memory.to("cpu")
         return super().forward(input)

-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
+    def slicing_forward(
+        self,
+        input,
+        memory_state: MemoryState = MemoryState.UNSET,
+    ) -> Tensor:
+        squeeze_out = False
+        if torch.is_tensor(input):
+            input = [input]
+            squeeze_out = True

-        super()._load_from_state_dict(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
+        cache_size = self.kernel_size[0] - self.stride[0]
+        cache = cache_send_recv(
+            input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2
+        )
+
+        # Single GPU inference - simplified memory management
+        if (
+            memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE]  # use_slicing
+            and not self.training
+            and (self.memory_device is not None)
+            and cache_size != 0
+        ):
+            if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
+                input[0] = torch.cat([cache, input[0]], dim=2)
+                cache = None
+            if cache_size <= input[-1].size(2):
+                self.memory = input[-1][:, :, -cache_size:].detach().contiguous()
+                if self.memory_device == "cpu" and self.memory is not None:
+                    self.memory = self.memory.to("cpu")
+
+        padding = tuple(x for x in reversed(self.padding) for _ in range(2))
+        for i in range(len(input)):
+            # Prepare cache for next input slice.
+            next_cache = None
+            cache_size = 0
+            if i < len(input) - 1:
+                cache_len = cache.size(2) if cache is not None else 0
+                cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0)
+            if cache_size != 0:
+                if cache_size > input[i].size(2) and cache is not None:
+                    input[i] = torch.cat([cache, input[i]], dim=2)
+                    cache = None
+                assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}"
+                next_cache = input[i][:, :, -cache_size:]
+
+            # Conv forward for this input slice.
+            input[i] = self.memory_limit_conv(
+                input[i],
+                padding=padding,
+                prev_cache=cache
+            )
+
+            # Update cache.
+            cache = next_cache
+
+        return input[0] if squeeze_out else input
+
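
The MemoryState enum drives the temporal cache across these forwards: DISABLED runs stateless, INITIALIZING starts a fresh cache on the first slice, ACTIVE consumes and refreshes the cache on later slices, and UNSET is asserted away. The state sequence a caller produces for N slices, as slicing_encode/slicing_decode do further down (standalone sketch, not part of the commit):

    def states_for(num_slices):
        # first slice primes the per-conv memory, later slices reuse it
        return ["INITIALIZING"] + ["ACTIVE"] * (num_slices - 1)

    assert states_for(3) == ["INITIALIZING", "ACTIVE", "ACTIVE"]
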
 def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
     if times == 0:
         return tensor
@@ -488,6 +728,7 @@ class Upsample3D(nn.Module):
     def forward(
         self,
         hidden_states: torch.FloatTensor,
+        memory_state=None,
         **kwargs,
     ) -> torch.FloatTensor:
         assert hidden_states.shape[1] == self.channels
@@ -517,7 +758,7 @@ class Upsample3D(nn.Module):
                 z=self.temporal_ratio,
             )

-        if self.temporal_up:
+        if self.temporal_up and memory_state != MemoryState.ACTIVE:
             hidden_states[0] = remove_head(hidden_states[0])

         if not self.slicing:
@@ -525,9 +766,9 @@ class Upsample3D(nn.Module):

         if self.use_conv:
             if self.name == "conv":
-                hidden_states = self.conv(hidden_states)
+                hidden_states = self.conv(hidden_states, memory_state=memory_state)
             else:
-                hidden_states = self.Conv2d_0(hidden_states)
+                hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state)

         if not self.slicing:
             return hidden_states
@@ -594,6 +835,7 @@ class Downsample3D(nn.Module):
     def forward(
         self,
         hidden_states: torch.FloatTensor,
+        memory_state = None,
         **kwargs,
     ) -> torch.FloatTensor:

@@ -609,7 +851,7 @@ class Downsample3D(nn.Module):

         assert hidden_states.shape[1] == self.channels

-        hidden_states = self.conv(hidden_states)
+        hidden_states = self.conv(hidden_states, memory_state=memory_state)

         return hidden_states

@@ -707,7 +949,7 @@ class ResnetBlock3D(nn.Module):
         )

     def forward(
-        self, input_tensor, temb, **kwargs
+        self, input_tensor, temb, memory_state = None, **kwargs
     ):
         hidden_states = input_tensor

@@ -719,13 +961,13 @@ class ResnetBlock3D(nn.Module):
             if hidden_states.shape[0] >= 64:
                 input_tensor = input_tensor.contiguous()
                 hidden_states = hidden_states.contiguous()
-            input_tensor = self.upsample(input_tensor)
-            hidden_states = self.upsample(hidden_states)
+            input_tensor = self.upsample(input_tensor, memory_state=memory_state)
+            hidden_states = self.upsample(hidden_states, memory_state=memory_state)
         elif self.downsample is not None:
-            input_tensor = self.downsample(input_tensor)
-            hidden_states = self.downsample(hidden_states)
+            input_tensor = self.downsample(input_tensor, memory_state=memory_state)
+            hidden_states = self.downsample(hidden_states, memory_state=memory_state)

-        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv1(hidden_states, memory_state=memory_state)

         if self.time_emb_proj is not None:
             if not self.skip_time_act:
@@ -740,10 +982,10 @@ class ResnetBlock3D(nn.Module):
         hidden_states = self.nonlinearity(hidden_states)

         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states = self.conv2(hidden_states, memory_state=memory_state)

         if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor)
+            input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)

         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

@@ -819,15 +1061,16 @@ class DownEncoderBlock3D(nn.Module):
     def forward(
         self,
         hidden_states: torch.FloatTensor,
+        memory_state = None,
         **kwargs,
     ) -> torch.FloatTensor:
         for resnet, temporal in zip(self.resnets, self.temporal_modules):
-            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
             hidden_states = temporal(hidden_states)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states, memory_state=memory_state)

         return hidden_states

@@ -907,14 +1150,15 @@ class UpDecoderBlock3D(nn.Module):
         self,
         hidden_states: torch.FloatTensor,
         temb: Optional[torch.FloatTensor] = None,
+        memory_state=None
     ) -> torch.FloatTensor:
         for resnet, temporal in zip(self.resnets, self.temporal_modules):
-            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
             hidden_states = temporal(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, memory_state=memory_state)

         return hidden_states

@@ -1008,9 +1252,9 @@ class UNetMidBlock3D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states, temb=None):
+    def forward(self, hidden_states, temb=None, memory_state=None):
         video_length, frame_height, frame_width = hidden_states.size()[-3:]
-        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
@@ -1018,7 +1262,7 @@ class UNetMidBlock3D(nn.Module):
                 hidden_states = rearrange(
                     hidden_states, "(b f) c h w -> b c f h w", f=video_length
                 )
-            hidden_states = resnet(hidden_states, temb)
+            hidden_states = resnet(hidden_states, temb, memory_state=memory_state)

         return hidden_states

@@ -1136,10 +1380,11 @@ class Encoder3D(nn.Module):
         self,
         sample: torch.FloatTensor,
         extra_cond=None,
+        memory_state = None
     ) -> torch.FloatTensor:
         r"""The forward method of the `Encoder` class."""
         sample = sample.to(next(self.parameters()).device)
-        sample = self.conv_in(sample)
+        sample = self.conv_in(sample, memory_state = memory_state)
         if self.training and self.gradient_checkpointing:

             def create_custom_forward(module):
@@ -1164,17 +1409,17 @@ class Encoder3D(nn.Module):
         # down
         # [Override] add extra block and extra cond
         for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
-            sample = down_block(sample)
+            sample = down_block(sample, memory_state=memory_state)
             if extra_block is not None:
                 sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:])

         # middle
-        sample = self.mid_block(sample)
+        sample = self.mid_block(sample, memory_state=memory_state)

         # post-process
         sample = causal_norm_wrapper(self.conv_norm_out, sample)
         sample = self.conv_act(sample)
-        sample = self.conv_out(sample)
+        sample = self.conv_out(sample, memory_state = memory_state)

         return sample

@@ -1282,74 +1527,90 @@ class Decoder3D(nn.Module):
         self,
         sample: torch.FloatTensor,
         latent_embeds: Optional[torch.FloatTensor] = None,
+        memory_state = None,
     ) -> torch.FloatTensor:

         sample = sample.to(next(self.parameters()).device)
-        sample = self.conv_in(sample)
+        sample = self.conv_in(sample, memory_state=memory_state)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         # middle
-        sample = self.mid_block(sample, latent_embeds)
+        sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
         sample = sample.to(upscale_dtype)

         # up
         for up_block in self.up_blocks:
-            sample = up_block(sample, latent_embeds)
+            sample = up_block(sample, latent_embeds, memory_state=memory_state)

         # post-process
         sample = causal_norm_wrapper(self.conv_norm_out, sample)
         sample = self.conv_act(sample)
-        sample = self.conv_out(sample)
+        sample = self.conv_out(sample, memory_state=memory_state)

         return sample

-def wavelet_blur(image: Tensor, radius: int):
-    """
-    Apply wavelet blur to the input tensor.
-    """
-    # input shape: (1, 3, H, W)
-    # convolution kernel
+def wavelet_blur(image: Tensor, radius):
+    max_safe_radius = max(1, min(image.shape[-2:]) // 8)
+    if radius > max_safe_radius:
+        radius = max_safe_radius
+
+    num_channels = image.shape[1]
+
     kernel_vals = [
         [0.0625, 0.125, 0.0625],
         [0.125, 0.25, 0.125],
         [0.0625, 0.125, 0.0625],
     ]
     kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
-    # add channel dimensions to the kernel to make it a 4D tensor
-    kernel = kernel[None, None]
-    # repeat the kernel across all input channels
-    kernel = kernel.repeat(3, 1, 1, 1)
-    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
-    # apply convolution
-    output = F.conv2d(image, kernel, groups=3, dilation=radius)
+    kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
+
+    image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
+    output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
+
     return output

-def wavelet_decomposition(image: Tensor, levels=5):
-    """
-    Apply wavelet decomposition to the input tensor.
-    This function only returns the low frequency & the high frequency.
-    """
+def wavelet_decomposition(image: Tensor, levels: int = 5):
     high_freq = torch.zeros_like(image)

     for i in range(levels):
         radius = 2 ** i
         low_freq = wavelet_blur(image, radius)
-        high_freq += (image - low_freq)
+        high_freq.add_(image).sub_(low_freq)
         image = low_freq

     return high_freq, low_freq

-def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
-    """
-    Apply wavelet decomposition, so that the content will have the same color as the style.
-    """
-    # calculate the wavelet decomposition of the content feature
+def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
+    if content_feat.shape != style_feat.shape:
+        # Resize style to match content spatial dimensions
+        if len(content_feat.shape) >= 3:
+            # safe_interpolate_operation handles FP16 conversion automatically
+            style_feat = safe_interpolate_operation(
+                style_feat,
+                size=content_feat.shape[-2:],
+                mode='bilinear',
+                align_corners=False
+            )
+
+    # Decompose both features into frequency components
     content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
-    del content_low_freq
-    # calculate the wavelet decomposition of the style feature
+    del content_low_freq  # Free memory immediately
+
     style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
-    del style_high_freq
-    # reconstruct the content feature with the style's high frequency
-    return content_high_freq + style_low_freq
+    del style_high_freq  # Free memory immediately
+
+    if content_high_freq.shape != style_low_freq.shape:
+        style_low_freq = safe_interpolate_operation(
+            style_low_freq,
+            size=content_high_freq.shape[-2:],
+            mode='bilinear',
+            align_corners=False
+        )
+
+    content_high_freq.add_(style_low_freq)
+
+    return content_high_freq.clamp_(-1.0, 1.0)

 class VideoAutoencoderKL(nn.Module):
     def __init__(
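
wavelet_reconstruction keeps the content's high-frequency detail and swaps in the style's low-frequency band (which carries the color), now with in-place add_/clamp_ to avoid full-size temporaries. It relies on wavelet_decomposition being a telescoping sum, so high + low always reconstructs the input exactly. That identity in isolation (standalone sketch, not part of the commit):

    import torch
    import torch.nn.functional as F

    def blur(img, radius):                     # same 3x3 binomial kernel as above
        k = torch.tensor([[0.0625, 0.125, 0.0625],
                          [0.125, 0.25, 0.125],
                          [0.0625, 0.125, 0.0625]])
        k = k[None, None].repeat(img.shape[1], 1, 1, 1)
        img = F.pad(img, (radius,) * 4, mode='replicate')
        return F.conv2d(img, k, groups=img.shape[1], dilation=radius)

    x = torch.rand(1, 3, 32, 32)
    high, image = torch.zeros_like(x), x
    for i in range(3):
        low = blur(image, 2 ** i)
        high, image = high + (image - low), low
    assert torch.allclose(high + image, x, atol=1e-5)   # telescoping: high + low == input
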
@@ -1368,9 +1629,12 @@ class VideoAutoencoderKL(nn.Module):
         time_receptive_field: _receptive_field_t = "full",
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
+        slicing_sample_min_size = 4,
         *args,
         **kwargs,
     ):
+        self.slicing_sample_min_size = slicing_sample_min_size
+        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
         extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
         block_out_channels = (128, 256, 512, 512)
         down_block_types = ("DownEncoderBlock3D",) * 4
@@ -1438,9 +1702,11 @@ class VideoAutoencoderKL(nn.Module):
         self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
         self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

+        self.use_slicing = True
+
     def encode(self, x: torch.FloatTensor, return_dict: bool = True):
         h = self.slicing_encode(x)
-        posterior = DiagonalGaussianDistribution(h).sample()
+        posterior = DiagonalGaussianDistribution(h).mode()

         if not return_dict:
             return (posterior,)
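
Switching encode from .sample() to .mode() makes encoding deterministic: mode() returns the posterior mean, while sample() adds std-scaled Gaussian noise. The difference on the class above (standalone sketch, not part of the commit):

    import torch
    mean, std = torch.zeros(4), torch.ones(4)
    sampled = mean + std * torch.randn(4)   # stochastic: changes on every call
    mode = mean                             # deterministic: identical across calls
    assert torch.equal(mode, mean)
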
@@ -1458,29 +1724,71 @@ class VideoAutoencoderKL(nn.Module):
         return decoded

     def _encode(
-        self, x: torch.Tensor
+        self, x, memory_state
     ) -> torch.Tensor:
         _x = x.to(self.device)
-        h = self.encoder(_x,)
+        h = self.encoder(_x, memory_state=memory_state)
         if self.quant_conv is not None:
-            output = self.quant_conv(h)
+            output = self.quant_conv(h, memory_state=memory_state)
         else:
             output = h
         return output.to(x.device)

     def _decode(
-        self, z: torch.Tensor
+        self, z, memory_state
     ) -> torch.Tensor:
-        latent = z.to(self.device)
+        _z = z.to(self.device)
+
         if self.post_quant_conv is not None:
-            latent = self.post_quant_conv(latent)
-        output = self.decoder(latent)
+            _z = self.post_quant_conv(_z, memory_state=memory_state)
+
+        output = self.decoder(_z, memory_state=memory_state)
         return output.to(z.device)

     def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
+        sp_size = 1
+        if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
+            x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
+            encoded_slices = [
+                self._encode(
+                    torch.cat((x[:, :, :1], x_slices[0]), dim=2),
+                    memory_state=MemoryState.INITIALIZING,
+                )
+            ]
+            for x_idx in range(1, len(x_slices)):
+                encoded_slices.append(
+                    self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
+                )
+            out = torch.cat(encoded_slices, dim=2)
+            modules_with_memory = [m for m in self.modules()
+                                   if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
+            for m in modules_with_memory:
+                m.memory = None
+            return out
-        return self._encode(x)
+        else:
+            # _encode now requires a memory_state; DISABLED keeps this unsliced path stateless
+            return self._encode(x, memory_state=MemoryState.DISABLED)

     def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
+        sp_size = 1
+        if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
+            z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
+            decoded_slices = [
+                self._decode(
+                    torch.cat((z[:, :, :1], z_slices[0]), dim=2),
+                    memory_state=MemoryState.INITIALIZING
+                )
+            ]
+            for z_idx in range(1, len(z_slices)):
+                decoded_slices.append(
+                    self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
+                )
+            out = torch.cat(decoded_slices, dim=2)
+            modules_with_memory = [m for m in self.modules()
+                                   if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
+            for m in modules_with_memory:
+                m.memory = None
+            return out
-        return self._decode(z)
+        else:
+            # same fix as slicing_encode: pass an explicit memory_state
+            return self._decode(z, memory_state=MemoryState.DISABLED)

     def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
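
slicing_encode keeps frame 0 attached to the first slice (the causal convs treat the first frame specially), walks the remaining frames in slicing_sample_min_size chunks under INITIALIZING/ACTIVE states, then clears every InflatedCausalConv3d.memory so no state leaks into the next call. The slice boundaries in isolation (standalone sketch, not part of the commit):

    import torch
    x = torch.zeros(1, 3, 13, 8, 8)                 # 13 frames total
    slices = x[:, :, 1:].split(4, dim=2)            # slicing_sample_min_size = 4
    assert [s.shape[2] for s in slices] == [4, 4, 4]
    first = torch.cat((x[:, :, :1], slices[0]), dim=2)
    assert first.shape[2] == 5                      # frame 0 rides with the first slice
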
@@ -1531,6 +1839,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         self.freeze_encoder = freeze_encoder
         self.original_image_video = None
         super().__init__(*args, **kwargs)
+        self.set_memory_limit(0.5, 0.5)

     def forward(self, x: torch.FloatTensor):
         with torch.no_grad() if self.freeze_encoder else nullcontext():
@@ -1567,8 +1876,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):

         target_device = comfy.model_management.get_torch_device()
         self.decoder.to(target_device)
+        if self.tiled_args.get("enable_tiling", None) is not None:
+            self.enable_tiling = self.tiled_args.pop("enable_tiling", False)
+
+        if self.enable_tiling:
             x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)
-        #x = super().decode(latent).squeeze(2)
+        else:
+            x = super().decode_(latent).squeeze(2)

         input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w")
         if x.ndim == 4:
@@ -1581,6 +1895,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         x = rearrange(x, "b c t h w -> (b t) c h w")
+
         x = wavelet_reconstruction(x, input)

         x = x.unsqueeze(0)
         o_h, o_w = self.img_dims
         x = x[..., :o_h, :o_w]
@@ -1595,8 +1910,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return x

     def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
-        # TODO
-        #set_norm_limit(norm_max_mem)
+        set_norm_limit(norm_max_mem)
         for m in self.modules():
             if isinstance(m, InflatedCausalConv3d):
                 m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))

comfy_extras/nodes_seedvr.py (the second file in this commit; the module that provides tiled_vae, imported above):
@@ -14,11 +14,12 @@ from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode

 @torch.inference_mode()
-def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
+def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):

     gc.collect()
     torch.cuda.empty_cache()

+    x = x.to(next(vae_model.parameters()).dtype)
     if x.ndim != 5:
         x = x.unsqueeze(2)

@@ -30,9 +31,6 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     if encode:
         ti_h, ti_w = tile_size
         ov_h, ov_w = tile_overlap
-        ti_t = temporal_size
-        ov_t = temporal_overlap

         target_d = (d + sf_t - 1) // sf_t
         target_h = (h + sf_s - 1) // sf_s
         target_w = (w + sf_s - 1) // sf_s
@@ -41,21 +39,44 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
         ti_w = max(1, tile_size[1] // sf_s)
         ov_h = max(0, tile_overlap[0] // sf_s)
         ov_w = max(0, tile_overlap[1] // sf_s)
-        ti_t = max(1, temporal_size // sf_t)
-        ov_t = max(0, temporal_overlap // sf_t)

         target_d = d * sf_t
         target_h = h * sf_s
         target_w = w * sf_s

-    stride_t = max(1, ti_t - ov_t)
     stride_h = max(1, ti_h - ov_h)
     stride_w = max(1, ti_w - ov_w)

     storage_device = torch.device("cpu")

     result = None
     count = None

+    def run_temporal_chunks(spatial_tile):
+        chunk_results = []
+        t_dim_size = spatial_tile.shape[2]
+
+        if encode:
+            input_chunk = temporal_size
+        else:
+            input_chunk = max(1, temporal_size // sf_t)
+
+        for i in range(0, t_dim_size, input_chunk):
+            t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
+
+            if encode:
+                out = vae_model.encode(t_chunk)
+            else:
+                out = vae_model.decode_(t_chunk)
+
+            if isinstance(out, (tuple, list)): out = out[0]
+
+            if out.ndim == 4: out = out.unsqueeze(2)
+
+            chunk_results.append(out.to(storage_device))
+
+        return torch.cat(chunk_results, dim=2)
+
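
This helper is the heart of the speed fix: instead of tiling over time, height and width and cross-fading overlapping temporal windows, each spatial tile is now pushed through the VAE over its whole temporal extent in fixed-size chunks, letting the causal convolutions carry state between chunks. The chunk boundaries in isolation (standalone sketch, not part of the commit):

    # decode-side chunks shrink by the temporal scale factor sf_t
    temporal_size, sf_t, t_dim_size = 16, 4, 10
    input_chunk = max(1, temporal_size // sf_t)     # 4 latent frames per chunk
    starts = list(range(0, t_dim_size, input_chunk))
    assert starts == [0, 4, 8]                      # last chunk is the 2-frame remainder
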
     ramp_cache = {}
     def get_ramp(steps):
         if steps not in ramp_cache:
@@ -63,9 +84,8 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
             ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
         return ramp_cache[steps]

-    bar = ProgressBar(d // stride_t)
-    for t_idx in range(0, d, stride_t):
-        t_end = min(t_idx + ti_t, d)
+    total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
+    bar = ProgressBar(total_tiles)

     for y_idx in range(0, h, stride_h):
         y_end = min(y_idx + ti_h, h)
@@ -73,49 +93,30 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
         for x_idx in range(0, w, stride_w):
             x_end = min(x_idx + ti_w, w)

-            tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end]
+            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]

-            if encode:
-                tile_out = vae_model.encode(tile_x)[0]
-            else:
-                tile_out = vae_model.decode_(tile_x)
-
-            if tile_out.ndim == 4:
-                tile_out = tile_out.unsqueeze(2)
-
-            tile_out = tile_out.to(storage_device).float()
+            # Run VAE
+            tile_out = run_temporal_chunks(tile_x)

             if result is None:
                 b_out, c_out = tile_out.shape[0], tile_out.shape[1]
                 result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
-                count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
+                count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)

             if encode:
-                ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
                 ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
                 xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]

-                cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
                 cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
                 cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
             else:
-                ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2]
                 ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
                 xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]

-                cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2))
                 cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
                 cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))

-            w_t = torch.ones((tile_out.shape[2],), device=storage_device)
             w_h = torch.ones((tile_out.shape[3],), device=storage_device)
             w_w = torch.ones((tile_out.shape[4],), device=storage_device)

-            if cur_ov_t > 0:
-                r = get_ramp(cur_ov_t)
-                if t_idx > 0: w_t[:cur_ov_t] = r
-                if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
-
             if cur_ov_h > 0:
                 r = get_ramp(cur_ov_h)
                 if y_idx > 0: w_h[:cur_ov_h] = r
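
get_ramp caches a half-cosine fade used to blend overlapping tile borders: assuming t is a 0-to-1 linspace as in the cached expression above, weights rise smoothly from 0 to 1 across the overlap, and the facing edge of the neighbouring tile uses the mirrored 1 - r, so every overlapped pixel's weights sum to 1. The ramp by itself (standalone sketch, not part of the commit):

    import torch
    steps = 4
    t = torch.linspace(0, 1, steps)
    r = 0.5 - 0.5 * torch.cos(t * torch.pi)          # 0.00, 0.25, 0.75, 1.00
    assert torch.allclose(r + (1.0 - r), torch.ones(steps))
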
@@ -126,14 +127,19 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                 if x_idx > 0: w_w[:cur_ov_w] = r
                 if x_end < w: w_w[-cur_ov_w:] = 1.0 - r

-            final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
+            final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
+
+            valid_d = min(tile_out.shape[2], result.shape[2])
+            tile_out = tile_out[:, :, :valid_d, :, :]
+
             tile_out.mul_(final_weight)
-            result[:, :, ts:te, ys:ye, xs:xe] += tile_out
-            count[:, :, ts:te, ys:ye, xs:xe] += final_weight
-
-            del tile_out, final_weight, tile_x, w_t, w_h, w_w
+            result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
+            count[:, :, :, ys:ye, xs:xe] += final_weight
+
+            del tile_out, final_weight, tile_x, w_h, w_w
             bar.update(1)

     result.div_(count.clamp(min=1e-6))

     if result.device != x.device:
@@ -253,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Int.Input("spatial_tile_size", default = 512, min = -1),
                 io.Int.Input("temporal_tile_size", default = 8, min = -1),
                 io.Int.Input("spatial_overlap", default = 64, min = -1),
-                io.Int.Input("temporal_overlap", default = 8, min = -1),
+                io.Boolean.Input("enable_tiling", default=False)
             ],
             outputs = [
                 io.Latent.Output("vae_conditioning")
@@ -261,7 +267,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         )

     @classmethod
-    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
+    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
         device = vae.patcher.load_device

         offload_device = comfy.model_management.intermediate_device()
@@ -296,9 +302,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae_model.original_image_video = images

         args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
-                "temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap}
-        vae_model.tiled_args = args
-        latent = tiled_vae(images, vae_model, encode=True, **args)
+                "temporal_size":temporal_tile_size}
+        if enable_tiling:
+            latent = tiled_vae(images, vae_model, encode=True, **args)
+        else:
+            latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]
+
+        args["enable_tiling"] = enable_tiling
+        vae_model.tiled_args = args

         vae_model = vae_model.to(offload_device)
         vae_model.img_dims = [o_h, o_w]
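
On the node side, the temporal_overlap input is replaced by an enable_tiling toggle: tiling is now opt-in, a disabled toggle encodes the full clip in one call, and the flag is stashed in vae_model.tiled_args so the decode path above picks the matching branch. What decode later receives (illustration, not part of the commit):

    args = {"tile_size": (512, 512), "tile_overlap": (64, 64), "temporal_size": 8}
    args["enable_tiling"] = False        # decode then calls super().decode_() directly
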