Lightricks VAE: release intermediate tensors earlier and scale the decode
chunk size with device memory.

* causal_conv3d.py: drop the local references to `pieces` and `cached` once
  they have been concatenated, and replace the per-thread cache entry with
  (None, True) when the end was reached and nothing is re-cached.
* causal_video_autoencoder.py: replace the fixed 128 MiB chunk limit with
  get_max_chunk_size(), which returns 32 MiB at <= 6 GiB of total device
  memory, 128 MiB at >= 24 GiB, and a linear interpolation in between.
  run_up() now receives its input in a one-element list and clears the slot
  after reading it, so the caller no longer holds a reference while the
  callee runs.

NOTE(review): leading whitespace in the hunks below was reconstructed from the
hunk line counts and the Python block structure; verify the context lines
against the target files before applying.

diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py
index b8341edbc..356394239 100644
--- a/comfy/ldm/lightricks/vae/causal_conv3d.py
+++ b/comfy/ldm/lightricks/vae/causal_conv3d.py
@@ -65,9 +65,13 @@ class CausalConv3d(nn.Module):
             self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
 
         x = torch.cat(pieces, dim=2)
+        del pieces
+        del cached
 
         if needs_caching:
             self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+        elif is_end:
+            self.temporal_cache_state[tid] = (None, True)
 
         return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
 
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 9f14f64a5..0504140ef 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -297,7 +297,23 @@ class Encoder(nn.Module):
                     module.temporal_cache_state.pop(tid, None)
 
 
-MAX_CHUNK_SIZE=(128 * 1024 ** 2)
+MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
+MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
+MIN_CHUNK_SIZE = 32 * 1024 ** 2
+MAX_CHUNK_SIZE = 128 * 1024 ** 2
+
+def get_max_chunk_size(device: torch.device) -> int:
+    total_memory = comfy.model_management.get_total_memory(dev=device)
+
+    if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
+        return MIN_CHUNK_SIZE
+    if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
+        return MAX_CHUNK_SIZE
+
+    interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
+        MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
+    )
+    return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
 
 class Decoder(nn.Module):
     r"""
@@ -525,8 +541,11 @@ class Decoder(nn.Module):
             timestep_shift_scale = ada_values.unbind(dim=1)
 
         output = []
+        max_chunk_size = get_max_chunk_size(sample.device)
 
-        def run_up(idx, sample, ended):
+        def run_up(idx, sample_ref, ended):
+            sample = sample_ref[0]
+            sample_ref[0] = None
             if idx >= len(self.up_blocks):
                 sample = self.conv_norm_out(sample)
                 if timestep_shift_scale is not None:
@@ -554,13 +573,21 @@ class Decoder(nn.Module):
                 return
 
             total_bytes = sample.numel() * sample.element_size()
-            num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
-            samples = torch.chunk(sample, chunks=num_chunks, dim=2)
+            num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
 
-            for chunk_idx, sample1 in enumerate(samples):
-                run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
+            if num_chunks == 1:
+                # when we are not chunking, detach our x so the callee can free it as soon as they are done
+                next_sample_ref = [sample]
+                del sample
+                run_up(idx + 1, next_sample_ref, ended)
+                return
+            else:
+                samples = torch.chunk(sample, chunks=num_chunks, dim=2)
 
-        run_up(0, sample, True)
+                for chunk_idx, sample1 in enumerate(samples):
+                    run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
+
+        run_up(0, [sample], True)
 
         sample = torch.cat(output, dim=2)
         sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)