mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-19 16:16:00 +08:00
Reduce LTX VAE VRAM usage and save use cases from OOMs/Tiler (#13013)
* ltx: vae: scale the chunk size with the users VRAM Scale this linearly down for users with low VRAM. * ltx: vae: free non-chunking recursive intermediates * ltx: vae: cleanup some intermediates The conv layer can be the VRAM peak and it does a torch.cat. So cleanup the pieces of the cat. Also clear our the cache ASAP as each layer detect its end as this VAE surges in VRAM at the end due to the ended padding increasing the size of the final frame convolutions off-the-books to the chunker. So if all the earlier layers free up their cache it can offset that surge. Its a fragmentation nightmare, and the chance of it having to recache the pyt allocator is very high, but you wont OOM.
This commit is contained in:
parent
ed7c2c6579
commit
1a157e1f97
@ -65,9 +65,13 @@ class CausalConv3d(nn.Module):
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
del pieces
|
||||
del cached
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
elif is_end:
|
||||
self.temporal_cache_state[tid] = (None, True)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
|
||||
@ -297,7 +297,23 @@ class Encoder(nn.Module):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||
|
||||
def get_max_chunk_size(device: torch.device) -> int:
|
||||
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||
|
||||
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||
return MIN_CHUNK_SIZE
|
||||
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||
return MAX_CHUNK_SIZE
|
||||
|
||||
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||
)
|
||||
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@ -525,8 +541,11 @@ class Decoder(nn.Module):
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
output = []
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
|
||||
def run_up(idx, sample, ended):
|
||||
def run_up(idx, sample_ref, ended):
|
||||
sample = sample_ref[0]
|
||||
sample_ref[0] = None
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
@ -554,13 +573,21 @@ class Decoder(nn.Module):
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
if num_chunks == 1:
|
||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||
next_sample_ref = [sample]
|
||||
del sample
|
||||
run_up(idx + 1, next_sample_ref, ended)
|
||||
return
|
||||
else:
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
run_up(0, sample, True)
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, [sample], True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user