ltx: vae: scale the chunk size with the user's VRAM

Scale the decode chunk size down linearly for users with low VRAM.
Rattus 2026-03-14 19:46:32 +10:00
parent 7a16e8aa4e
commit d4a3cfc649

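The new helper clamps the chunk size to 32 MiB at or below 6 GiB of total VRAM, to 128 MiB at or above 24 GiB, and interpolates linearly in between. A minimal standalone sketch of that mapping (the helper name here is illustrative; the real code reads VRAM via comfy.model_management.get_total_memory, while this sketch takes a plain byte count so it runs anywhere):

def chunk_size_for_vram(total_memory: int) -> int:
    # Same bounds as the diff below: 6/24 GiB VRAM, 32/128 MiB chunks.
    min_vram, max_vram = 6 * 1024 ** 3, 24 * 1024 ** 3
    min_chunk, max_chunk = 32 * 1024 ** 2, 128 * 1024 ** 2
    if total_memory <= min_vram:
        return min_chunk
    if total_memory >= max_vram:
        return max_chunk
    interp = (total_memory - min_vram) / (max_vram - min_vram)
    return int(min_chunk + interp * (max_chunk - min_chunk))

# e.g. an 8 GiB card: 32 MiB + (2/18) * 96 MiB ~= 42.7 MiB per chunk
print(chunk_size_for_vram(8 * 1024 ** 3) / 1024 ** 2)  # ~42.67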

@@ -297,7 +297,23 @@ class Encoder(nn.Module):
             module.temporal_cache_state.pop(tid, None)
 
-MAX_CHUNK_SIZE=(128 * 1024 ** 2)
+MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
+MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
+MIN_CHUNK_SIZE = 32 * 1024 ** 2
+MAX_CHUNK_SIZE = 128 * 1024 ** 2
+
+def get_max_chunk_size(device: torch.device) -> int:
+    total_memory = comfy.model_management.get_total_memory(dev=device)
+    if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
+        return MIN_CHUNK_SIZE
+    if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
+        return MAX_CHUNK_SIZE
+    interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
+        MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
+    )
+    return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
+
 class Decoder(nn.Module):
     r"""
@@ -525,6 +541,7 @@ class Decoder(nn.Module):
         timestep_shift_scale = ada_values.unbind(dim=1)
 
         output = []
+        max_chunk_size = get_max_chunk_size(sample.device)
 
         def run_up(idx, sample, ended):
             if idx >= len(self.up_blocks):
@@ -554,7 +571,7 @@ class Decoder(nn.Module):
                 return
 
             total_bytes = sample.numel() * sample.element_size()
-            num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
+            num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
             samples = torch.chunk(sample, chunks=num_chunks, dim=2)
             for chunk_idx, sample1 in enumerate(samples):
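For context on the changed line: the num_chunks expression is a ceiling division, so any remainder of decoded bytes still gets its own chunk, and torch.chunk then slices the sample along the temporal axis (dim=2). A toy, standalone illustration with made-up tensor shapes:

import torch

# A (1, 128, 64, 64, 64) float32 tensor is exactly 128 MiB, so with the
# 32 MiB low-VRAM chunk size it splits into 4 even slices along dim=2.
sample = torch.zeros(1, 128, 64, 64, 64)
max_chunk_size = 32 * 1024 ** 2
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size  # ceiling division
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
print(num_chunks, [s.shape[2] for s in samples])  # 4 [16, 16, 16, 16]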