diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 70d612e86..b8341edbc 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -1,11 +1,11 @@ from typing import Tuple, Union +import threading import torch import torch.nn as nn import comfy.ops ops = comfy.ops.disable_weight_init - class CausalConv3d(nn.Module): def __init__( self, @@ -42,23 +42,34 @@ class CausalConv3d(nn.Module): padding_mode=spatial_padding_mode, groups=groups, ) + self.temporal_cache_state={} def forward(self, x, causal: bool = True): - if causal: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, self.time_kernel_size - 1, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x), dim=2) - else: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - last_frame_pad = x[:, :, -1:, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x + tid = threading.get_ident() + + cached, is_end = self.temporal_cache_state.get(tid, (None, False)) + if cached is None: + padding_length = self.time_kernel_size - 1 + if not causal: + padding_length = padding_length // 2 + if x.shape[2] == 0: + return x + cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1)) + pieces = [ cached, x ] + if is_end and not causal: + pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + + needs_caching = not is_end + if needs_caching and x.shape[2] >= self.time_kernel_size - 1: + needs_caching = False + self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + + x = torch.cat(pieces, dim=2) + + if needs_caching: + self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + + return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :] @property def weight(self): diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 75ed069ad..cbfdf412d 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -1,4 +1,5 @@ from __future__ import annotations +import threading import torch from torch import nn from functools import partial @@ -6,12 +7,35 @@ import math from einops import rearrange from typing import List, Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd +from .causal_conv3d import CausalConv3d from .pixel_norm import PixelNorm from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings import comfy.ops +from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def mark_conv3d_ended(module): + tid = threading.get_ident() + for _, m in module.named_modules(): + if isinstance(m, CausalConv3d): + current = m.temporal_cache_state.get(tid, (None, False)) + m.temporal_cache_state[tid] = (current[0], True) + +def split2(tensor, split_point, dim=2): + return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim) + +def add_exchange_cache(dest, cache_in, new_input, dim=2): + if dest is not None: + if cache_in is not None: + cache_to_dest = min(dest.shape[dim], cache_in.shape[dim]) + lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim) + lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim) + lead_in_dest.add_(lead_in_source) + body, new_input = split2(new_input, dest.shape[dim], dim) + dest.add_(body) + return torch_cat_if_needed([cache_in, new_input], dim=dim) + class Encoder(nn.Module): r""" The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. @@ -205,7 +229,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) @@ -254,6 +278,22 @@ class Encoder(nn.Module): return sample + def forward(self, *args, **kwargs): + #No encoder support so just flag the end so it doesnt use the cache. + mark_conv3d_ended(self) + try: + return self.forward_orig(*args, **kwargs) + finally: + tid = threading.get_ident() + for _, module in self.named_modules(): + # ComfyUI doesn't thread this kind of stuff today, but just in case + # we key on the thread to make it thread safe. + tid = threading.get_ident() + if hasattr(module, "temporal_cache_state"): + module.temporal_cache_state.pop(tid, None) + + +MAX_CHUNK_SIZE=(128 * 1024 ** 2) class Decoder(nn.Module): r""" @@ -341,18 +381,6 @@ class Decoder(nn.Module): timestep_conditioning=timestep_conditioning, spatial_padding_mode=spatial_padding_mode, ) - elif block_name == "attn_res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - attention_head_dim=block_params["attention_head_dim"], - spatial_padding_mode=spatial_padding_mode, - ) elif block_name == "res_x_y": output_channel = output_channel // block_params.get("multiplier", 2) block = ResnetBlock3D( @@ -428,8 +456,9 @@ class Decoder(nn.Module): ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - def forward( + def forward_orig( self, sample: torch.FloatTensor, timestep: Optional[torch.Tensor] = None, @@ -437,6 +466,7 @@ class Decoder(nn.Module): r"""The forward method of the `Decoder` class.""" batch_size = sample.shape[0] + mark_conv3d_ended(self.conv_in) sample = self.conv_in(sample, causal=self.causal) checkpoint_fn = ( @@ -445,24 +475,12 @@ class Decoder(nn.Module): else lambda x: x ) - scaled_timestep = None + timestep_shift_scale = None if self.timestep_conditioning: assert ( timestep is not None ), "should pass timestep with timestep_conditioning=True" scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) - - for up_block in self.up_blocks: - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - - sample = self.conv_norm_out(sample) - - if self.timestep_conditioning: embedded_timestep = self.last_time_embedder( timestep=scaled_timestep.flatten(), resolution=None, @@ -483,16 +501,62 @@ class Decoder(nn.Module): embedded_timestep.shape[-2], embedded_timestep.shape[-1], ) - shift, scale = ada_values.unbind(dim=1) - sample = sample * (1 + scale) + shift + timestep_shift_scale = ada_values.unbind(dim=1) - sample = self.conv_act(sample) - sample = self.conv_out(sample, causal=self.causal) + output = [] + + def run_up(idx, sample, ended): + if idx >= len(self.up_blocks): + sample = self.conv_norm_out(sample) + if timestep_shift_scale is not None: + shift, scale = timestep_shift_scale + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) + if ended: + mark_conv3d_ended(self.conv_out) + sample = self.conv_out(sample, causal=self.causal) + if sample is not None and sample.shape[2] > 0: + output.append(sample) + return + + up_block = self.up_blocks[idx] + if (ended): + mark_conv3d_ended(up_block) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + if sample is None or sample.shape[2] == 0: + return + + total_bytes = sample.numel() * sample.element_size() + num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE + samples = torch.chunk(sample, chunks=num_chunks, dim=2) + + for chunk_idx, sample1 in enumerate(samples): + run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1) + + run_up(0, sample, True) + sample = torch.cat(output, dim=2) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) return sample + def forward(self, *args, **kwargs): + try: + return self.forward_orig(*args, **kwargs) + finally: + for _, module in self.named_modules(): + #ComfyUI doesn't thread this kind of stuff today, but just incase + #we key on the thread to make it thread safe. + tid = threading.get_ident() + if hasattr(module, "temporal_cache_state"): + module.temporal_cache_state.pop(tid, None) + class UNetMidBlock3D(nn.Module): """ @@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module): ) self.residual = residual self.out_channels_reduction_factor = out_channels_reduction_factor + self.temporal_cache_state = {} def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None): + tid = threading.get_ident() + cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True)) + y = self.conv(x, causal=causal) + y = rearrange( + y, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv: + y = y[:, :, 1:, :, :] + drop_first_conv = False if self.residual: # Reshape and duplicate the input to match the output shape x_in = rearrange( @@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module): ) num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor x_in = x_in.repeat(1, num_repeat, 1, 1, 1) - if self.stride[0] == 2: + if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res: x_in = x_in[:, :, 1:, :, :] - x = self.conv(x, causal=causal) - x = rearrange( - x, - "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - if self.stride[0] == 2: - x = x[:, :, 1:, :, :] - if self.residual: - x = x + x_in - return x + drop_first_res = False + + if y.shape[2] == 0: + y = None + + cached = add_exchange_cache(y, cached, x_in, dim=2) + self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res) + + else: + self.temporal_cache_state[tid] = (None, drop_first_conv, False) + + return y class LayerNorm(nn.Module): def __init__(self, dim, eps, elementwise_affine=True) -> None: @@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module): torch.randn(4, in_channels) / in_channels**0.5 ) + self.temporal_cache_state={} + def _feed_spatial_noise( self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor ) -> torch.FloatTensor: @@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module): input_tensor = self.conv_shortcut(input_tensor) - output_tensor = input_tensor + hidden_states + tid = threading.get_ident() + cached = self.temporal_cache_state.get(tid, None) + cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2) + self.temporal_cache_state[tid] = cached - return output_tensor + return hidden_states def patchify(x, patch_size_hw, patch_size_t=1): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1ae3ef034..5a22ef030 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae(): import xformers.ops def torch_cat_if_needed(xl, dim): + xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: return torch.cat(xl, dim) - else: + elif len(xl) == 1: return xl[0] + else: + return None def get_timestep_embedding(timesteps, embedding_dim): """