Compare commits

..

3 Commits

Author SHA1 Message Date
xmarre
3d091d7797
Merge fdcc38b9ea into 1a157e1f97 2026-03-17 16:32:54 -05:00
rattus
1a157e1f97
Reduce LTX VAE VRAM usage and save use cases from OOMs/Tiler (#13013)
* ltx: vae: scale the chunk size with the users VRAM

Scale this linearly down for users with low VRAM.

* ltx: vae: free non-chunking recursive intermediates

* ltx: vae: cleanup some intermediates

The conv layer can be the VRAM peak and it does a torch.cat. So cleanup
the pieces of the cat. Also clear our the cache ASAP as each layer detect
its end as this VAE surges in VRAM at the end due to the ended padding
increasing the size of the final frame convolutions off-the-books to
the chunker. So if all the earlier layers free up their cache it can
offset that surge.

Its a fragmentation nightmare, and the chance of it having to recache the
pyt allocator is very high, but you wont OOM.
2026-03-17 17:32:43 -04:00
Christian Byrne
ed7c2c6579
Mark weight_dtype as advanced input in Load Diffusion Model node (#12769)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Mark the weight_dtype parameter in UNETLoader (Load Diffusion Model) as
an advanced input to reduce UI complexity for new users. The parameter
is now hidden behind an expandable Advanced section, matching the
pattern used for other advanced inputs like device, tile_size, and
overlap.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbaf1-d3c0-718e-a325-318baba86dec
2026-03-17 07:24:00 -07:00
3 changed files with 39 additions and 8 deletions

View File

@ -65,9 +65,13 @@ class CausalConv3d(nn.Module):
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

View File

@ -297,7 +297,23 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2
def get_max_chunk_size(device: torch.device) -> int:
total_memory = comfy.model_management.get_total_memory(dev=device)
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
return MIN_CHUNK_SIZE
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
return MAX_CHUNK_SIZE
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
)
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
class Decoder(nn.Module):
r"""
@ -525,8 +541,11 @@ class Decoder(nn.Module):
timestep_shift_scale = ada_values.unbind(dim=1)
output = []
max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample, ended):
def run_up(idx, sample_ref, ended):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
@ -554,13 +573,21 @@ class Decoder(nn.Module):
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
run_up(idx + 1, next_sample_ref, ended)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
run_up(0, sample, True)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
run_up(0, [sample], True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

View File

@ -952,7 +952,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"