mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
ltx: vae: scale the chunk size with the user's VRAM
Scale the chunk size down linearly for users with low VRAM.
This commit is contained in:
parent
7a16e8aa4e
commit
d4a3cfc649
@ -297,7 +297,23 @@ class Encoder(nn.Module):
|
|||||||
module.temporal_cache_state.pop(tid, None)
|
module.temporal_cache_state.pop(tid, None)
|
||||||
|
|
||||||
|
|
||||||
# Bounds for scaling the VAE decode chunk size with available VRAM.
# Devices at or below the low-VRAM bound decode in the smallest chunks;
# devices at or above the high-VRAM bound use the largest; in between
# the chunk size is linearly interpolated.
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3   # 6 GiB
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3  # 24 GiB
MIN_CHUNK_SIZE = 32 * 1024 ** 2              # 32 MiB
MAX_CHUNK_SIZE = 128 * 1024 ** 2             # 128 MiB


def get_max_chunk_size(device: torch.device, total_memory=None) -> int:
    """Return the maximum VAE decode chunk size, in bytes, for *device*.

    The chunk size scales linearly with the device's total memory:
    at or below MIN_VRAM_FOR_CHUNK_SCALING it is MIN_CHUNK_SIZE, at or
    above MAX_VRAM_FOR_CHUNK_SCALING it is MAX_CHUNK_SIZE, and between
    the two bounds it is interpolated so low-VRAM cards decode in
    smaller pieces.

    Args:
        device: device whose total memory determines the chunk size.
        total_memory: optional override for the device's total memory in
            bytes; when None (the default) it is queried from
            comfy.model_management.get_total_memory.

    Returns:
        Chunk size in bytes, in [MIN_CHUNK_SIZE, MAX_CHUNK_SIZE].
    """
    if total_memory is None:
        total_memory = comfy.model_management.get_total_memory(dev=device)

    # Clamp to the bounds before interpolating.
    if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
        return MIN_CHUNK_SIZE
    if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
        return MAX_CHUNK_SIZE

    # Linear interpolation factor in (0, 1) between the VRAM bounds.
    interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
        MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
    )
    return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
@ -525,6 +541,7 @@ class Decoder(nn.Module):
|
|||||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
|
max_chunk_size = get_max_chunk_size(sample.device)
|
||||||
|
|
||||||
def run_up(idx, sample, ended):
|
def run_up(idx, sample, ended):
|
||||||
if idx >= len(self.up_blocks):
|
if idx >= len(self.up_blocks):
|
||||||
@ -554,7 +571,7 @@ class Decoder(nn.Module):
|
|||||||
return
|
return
|
||||||
|
|
||||||
total_bytes = sample.numel() * sample.element_size()
|
total_bytes = sample.numel() * sample.element_size()
|
||||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||||
|
|
||||||
for chunk_idx, sample1 in enumerate(samples):
|
for chunk_idx, sample1 in enumerate(samples):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user