Compare commits

..

1 Commits

Author SHA1 Message Date
xmarre
95194200d2
Merge fdcc38b9ea into 379fbd1a82 2026-03-17 06:48:21 +00:00
3 changed files with 8 additions and 39 deletions

View File

@ -65,13 +65,9 @@ class CausalConv3d(nn.Module):
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

View File

@ -297,23 +297,7 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2
def get_max_chunk_size(device: torch.device) -> int:
total_memory = comfy.model_management.get_total_memory(dev=device)
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
return MIN_CHUNK_SIZE
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
return MAX_CHUNK_SIZE
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
)
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module):
r"""
@ -541,11 +525,8 @@ class Decoder(nn.Module):
timestep_shift_scale = ada_values.unbind(dim=1)
output = []
max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample_ref, ended):
sample = sample_ref[0]
sample_ref[0] = None
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
@ -573,21 +554,13 @@ class Decoder(nn.Module):
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
run_up(idx + 1, next_sample_ref, ended)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
run_up(0, [sample], True)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

View File

@ -952,7 +952,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"