Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-28 23:30:16 +08:00)
Reduce LTX2 VAE VRAM consumption (#12028)
* causal_video_ae: Remove attention ResNet
The attention_head_dim argument does not exist on this constructor, so this is dead code.
Remove it, as a generic attention mid-block in the VAE conflicts with the
temporal roll.
* ltx-vae: consolidate causal/non-causal code paths
* ltx-vae: add cache rolling adder
* ltx-vae: use cached adder for resnet
* ltx-vae: Implement rolling VAE
Implement a temporal rolling VAE for the LTX2 VAE.
Usually when doing temporal rolling VAEs you can just chunk on time, relying
on causality, and cache behind you as you go. The LTX VAE, however, is
non-causal.
So go whole hog and implement per-layer run-ahead and backpressure between
the decoder layers, using recursive state between the layers.
Operations are amended with a temporal_cache_state{} which they can use to
hold any state they need for partial execution. Convolutions cache their
trailing input, up to N-1 frames, and skip connections need to cache the
mismatch between convolution input and output that happens due to missing
future (non-causal) input.
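
As a rough editorial sketch of the convolution-side cache (not part of the commit; RollingConv3d and its arguments are made-up illustration names):

import threading
import torch

class RollingConv3d(torch.nn.Module):
    # Sketch: keep the last (kernel_t - 1) input frames per thread so the next
    # chunk can be convolved as if the whole sequence were still present.
    def __init__(self, conv, kernel_t):
        super().__init__()
        self.conv = conv                        # an ordinary nn.Conv3d
        self.kernel_t = kernel_t                # temporal kernel size
        self.temporal_cache_state = {}          # keyed by thread id, as in the diff below

    def forward(self, x):                       # x is (batch, channels, time, height, width)
        tid = threading.get_ident()
        cached = self.temporal_cache_state.get(tid)
        if cached is not None:
            x = torch.cat([cached, x], dim=2)   # prepend frames carried over from the previous chunk
        # remember the trailing kernel_t - 1 frames for the next partial call
        self.temporal_cache_state[tid] = x[:, :, -(self.kernel_t - 1):]
        return self.conv(x)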
Each call to run_up() processes a layer across a range of input that
may or may not be complete. It goes depth first to process as much as
possible, trying to digest frames into the final output ASAP. If layers run
out of input due to convolution losses, they simply return without action,
effectively applying back-pressure to the earlier layers. As the earlier
layers do more work and call deeper, the partial states are reconciled
and output continues to be digested depth first as much as possible.
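
In rough pseudocode, the run-ahead and back-pressure loop looks like this (hypothetical names, heavily simplified relative to the run_up() added in the Decoder diff below):

def run_layer(layers, idx, frames, output, quota_frames=2):
    # Depth first: push frames through the remaining layers as soon as they exist.
    if idx == len(layers):
        output.append(frames)          # fully decoded, can leave the working set early
        return
    frames = layers[idx](frames)       # a starved layer yields an empty time axis
    if frames.shape[2] == 0:
        return                         # back-pressure: wait for earlier layers to supply more input
    for chunk in frames.split(quota_frames, dim=2):
        run_layer(layers, idx + 1, chunk, output)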
Chunking is done using a size quota rather than a fixed frame length; any
layer can initiate chunking, and multiple layers can chunk at different
granularities. This removes the old limitation of always having to process
one latent frame in its entirety and having to hold 8 full decoded frames as
the VRAM peak.
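
The quota itself is simple arithmetic; the sketch below mirrors the MAX_CHUNK_SIZE logic added to the Decoder further down (chunk_by_quota is just an illustrative name):

import torch

MAX_CHUNK_SIZE = 128 * 1024 ** 2    # 128 MiB per chunk, as in the diff

def chunk_by_quota(sample):
    total_bytes = sample.numel() * sample.element_size()                # bytes held by this tensor
    num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE   # ceil-divide against the quota
    return torch.chunk(sample, chunks=num_chunks, dim=2)                # split along the time axis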
This commit is contained in:
parent 8490eedadf
commit 0fd1b78736
@@ -1,11 +1,12 @@
 from typing import Tuple, Union
 
+import threading
 import torch
 import torch.nn as nn
 import comfy.ops
 ops = comfy.ops.disable_weight_init
 
 
 class CausalConv3d(nn.Module):
     def __init__(
         self,
@@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
             padding_mode=spatial_padding_mode,
             groups=groups,
         )
+        self.temporal_cache_state={}
 
     def forward(self, x, causal: bool = True):
-        if causal:
-            first_frame_pad = x[:, :, :1, :, :].repeat(
-                (1, 1, self.time_kernel_size - 1, 1, 1)
-            )
-            x = torch.concatenate((first_frame_pad, x), dim=2)
-        else:
-            first_frame_pad = x[:, :, :1, :, :].repeat(
-                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
-            )
-            last_frame_pad = x[:, :, -1:, :, :].repeat(
-                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
-            )
-            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
-        x = self.conv(x)
-        return x
+        tid = threading.get_ident()
+
+        cached, is_end = self.temporal_cache_state.get(tid, (None, False))
+        if cached is None:
+            padding_length = self.time_kernel_size - 1
+            if not causal:
+                padding_length = padding_length // 2
+            if x.shape[2] == 0:
+                return x
+            cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
+        pieces = [ cached, x ]
+        if is_end and not causal:
+            pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
+
+        needs_caching = not is_end
+        if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
+            needs_caching = False
+            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+
+        x = torch.cat(pieces, dim=2)
+
+        if needs_caching:
+            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+
+        return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
 
     @property
     def weight(self):
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import threading
 import torch
 from torch import nn
 from functools import partial
@@ -6,12 +7,35 @@ import math
 from einops import rearrange
 from typing import List, Optional, Tuple, Union
 from .conv_nd_factory import make_conv_nd, make_linear_nd
+from .causal_conv3d import CausalConv3d
 from .pixel_norm import PixelNorm
 from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
 import comfy.ops
+from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
+
 ops = comfy.ops.disable_weight_init
 
+def mark_conv3d_ended(module):
+    tid = threading.get_ident()
+    for _, m in module.named_modules():
+        if isinstance(m, CausalConv3d):
+            current = m.temporal_cache_state.get(tid, (None, False))
+            m.temporal_cache_state[tid] = (current[0], True)
+
+def split2(tensor, split_point, dim=2):
+    return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
+
+def add_exchange_cache(dest, cache_in, new_input, dim=2):
+    if dest is not None:
+        if cache_in is not None:
+            cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
+            lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
+            lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
+            lead_in_dest.add_(lead_in_source)
+        body, new_input = split2(new_input, dest.shape[dim], dim)
+        dest.add_(body)
+    return torch_cat_if_needed([cache_in, new_input], dim=dim)
 
 class Encoder(nn.Module):
     r"""
     The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -205,7 +229,7 @@ class Encoder(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         r"""The forward method of the `Encoder` class."""
 
         sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@@ -254,6 +278,22 @@
 
         return sample
 
+    def forward(self, *args, **kwargs):
+        #No encoder support so just flag the end so it doesnt use the cache.
+        mark_conv3d_ended(self)
+        try:
+            return self.forward_orig(*args, **kwargs)
+        finally:
+            tid = threading.get_ident()
+            for _, module in self.named_modules():
+                # ComfyUI doesn't thread this kind of stuff today, but just in case
+                # we key on the thread to make it thread safe.
+                tid = threading.get_ident()
+                if hasattr(module, "temporal_cache_state"):
+                    module.temporal_cache_state.pop(tid, None)
+
+
+MAX_CHUNK_SIZE=(128 * 1024 ** 2)
 
 class Decoder(nn.Module):
     r"""
@@ -341,18 +381,6 @@ class Decoder(nn.Module):
                     timestep_conditioning=timestep_conditioning,
                     spatial_padding_mode=spatial_padding_mode,
                 )
-            elif block_name == "attn_res_x":
-                block = UNetMidBlock3D(
-                    dims=dims,
-                    in_channels=input_channel,
-                    num_layers=block_params["num_layers"],
-                    resnet_groups=norm_num_groups,
-                    norm_layer=norm_layer,
-                    inject_noise=block_params.get("inject_noise", False),
-                    timestep_conditioning=timestep_conditioning,
-                    attention_head_dim=block_params["attention_head_dim"],
-                    spatial_padding_mode=spatial_padding_mode,
-                )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
                 block = ResnetBlock3D(
@@ -428,8 +456,9 @@
         )
         self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
 
+
     # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
-    def forward(
+    def forward_orig(
         self,
         sample: torch.FloatTensor,
         timestep: Optional[torch.Tensor] = None,
@@ -437,6 +466,7 @@ class Decoder(nn.Module):
         r"""The forward method of the `Decoder` class."""
         batch_size = sample.shape[0]
 
+        mark_conv3d_ended(self.conv_in)
         sample = self.conv_in(sample, causal=self.causal)
 
         checkpoint_fn = (
@@ -445,24 +475,12 @@
             else lambda x: x
         )
 
-        scaled_timestep = None
+        timestep_shift_scale = None
         if self.timestep_conditioning:
             assert (
                 timestep is not None
             ), "should pass timestep with timestep_conditioning=True"
             scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
-
-        for up_block in self.up_blocks:
-            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
-                sample = checkpoint_fn(up_block)(
-                    sample, causal=self.causal, timestep=scaled_timestep
-                )
-            else:
-                sample = checkpoint_fn(up_block)(sample, causal=self.causal)
-
-        sample = self.conv_norm_out(sample)
-
-        if self.timestep_conditioning:
             embedded_timestep = self.last_time_embedder(
                 timestep=scaled_timestep.flatten(),
                 resolution=None,
@@ -483,16 +501,62 @@
                 embedded_timestep.shape[-2],
                 embedded_timestep.shape[-1],
             )
-            shift, scale = ada_values.unbind(dim=1)
-            sample = sample * (1 + scale) + shift
+            timestep_shift_scale = ada_values.unbind(dim=1)
 
-        sample = self.conv_act(sample)
-        sample = self.conv_out(sample, causal=self.causal)
+        output = []
+
+        def run_up(idx, sample, ended):
+            if idx >= len(self.up_blocks):
+                sample = self.conv_norm_out(sample)
+                if timestep_shift_scale is not None:
+                    shift, scale = timestep_shift_scale
+                    sample = sample * (1 + scale) + shift
+                sample = self.conv_act(sample)
+                if ended:
+                    mark_conv3d_ended(self.conv_out)
+                sample = self.conv_out(sample, causal=self.causal)
+                if sample is not None and sample.shape[2] > 0:
+                    output.append(sample)
+                return
+
+            up_block = self.up_blocks[idx]
+            if (ended):
+                mark_conv3d_ended(up_block)
+            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
+                sample = checkpoint_fn(up_block)(
+                    sample, causal=self.causal, timestep=scaled_timestep
+                )
+            else:
+                sample = checkpoint_fn(up_block)(sample, causal=self.causal)
+
+            if sample is None or sample.shape[2] == 0:
+                return
+
+            total_bytes = sample.numel() * sample.element_size()
+            num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
+            samples = torch.chunk(sample, chunks=num_chunks, dim=2)
+
+            for chunk_idx, sample1 in enumerate(samples):
+                run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
+
+        run_up(0, sample, True)
+        sample = torch.cat(output, dim=2)
 
         sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
 
         return sample
 
+    def forward(self, *args, **kwargs):
+        try:
+            return self.forward_orig(*args, **kwargs)
+        finally:
+            for _, module in self.named_modules():
+                #ComfyUI doesn't thread this kind of stuff today, but just incase
+                #we key on the thread to make it thread safe.
+                tid = threading.get_ident()
+                if hasattr(module, "temporal_cache_state"):
+                    module.temporal_cache_state.pop(tid, None)
+
 
 class UNetMidBlock3D(nn.Module):
     """
@@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
         )
         self.residual = residual
         self.out_channels_reduction_factor = out_channels_reduction_factor
+        self.temporal_cache_state = {}
 
     def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
+        tid = threading.get_ident()
+        cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
+        y = self.conv(x, causal=causal)
+        y = rearrange(
+            y,
+            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
+            p1=self.stride[0],
+            p2=self.stride[1],
+            p3=self.stride[2],
+        )
+        if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
+            y = y[:, :, 1:, :, :]
+            drop_first_conv = False
         if self.residual:
             # Reshape and duplicate the input to match the output shape
             x_in = rearrange(
@@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
             )
             num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
             x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
-            if self.stride[0] == 2:
+            if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
                 x_in = x_in[:, :, 1:, :, :]
-        x = self.conv(x, causal=causal)
-        x = rearrange(
-            x,
-            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
-            p1=self.stride[0],
-            p2=self.stride[1],
-            p3=self.stride[2],
-        )
-        if self.stride[0] == 2:
-            x = x[:, :, 1:, :, :]
-        if self.residual:
-            x = x + x_in
-        return x
+                drop_first_res = False
+
+            if y.shape[2] == 0:
+                y = None
+
+            cached = add_exchange_cache(y, cached, x_in, dim=2)
+            self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
+
+        else:
+            self.temporal_cache_state[tid] = (None, drop_first_conv, False)
+
+        return y
 
 class LayerNorm(nn.Module):
     def __init__(self, dim, eps, elementwise_affine=True) -> None:
@@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
                 torch.randn(4, in_channels) / in_channels**0.5
             )
 
+        self.temporal_cache_state={}
+
     def _feed_spatial_noise(
         self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
     ) -> torch.FloatTensor:
@@ -880,9 +959,12 @@
 
             input_tensor = self.conv_shortcut(input_tensor)
 
-        output_tensor = input_tensor + hidden_states
+        tid = threading.get_ident()
+        cached = self.temporal_cache_state.get(tid, None)
+        cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
+        self.temporal_cache_state[tid] = cached
 
-        return output_tensor
+        return hidden_states
 
 
 def patchify(x, patch_size_hw, patch_size_t=1):
@@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
     import xformers.ops
 
 def torch_cat_if_needed(xl, dim):
+    xl = [x for x in xl if x is not None and x.shape[dim] > 0]
     if len(xl) > 1:
         return torch.cat(xl, dim)
-    else:
+    elif len(xl) == 1:
         return xl[0]
+    else:
+        return None
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """