From 0e25a6936ef41a56af87a4af174fa519da73b37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:15:10 +0300 Subject: [PATCH] Reduce video tiny VAE peak VRAM and decode time (CORE-127) (#13617) * Update taehv.py * Simplify * Simplify pixel_unshuffle dispatch --- comfy/taesd/taehv.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 6c06ce19d..696013200 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -7,6 +7,7 @@ from tqdm.auto import tqdm from collections import namedtuple, deque import comfy.ops +import comfy.model_management operations=comfy.ops.disable_weight_init DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) @@ -47,11 +48,14 @@ class TGrow(nn.Module): x = self.conv(x) return x.reshape(-1, C, H, W) -def apply_model_with_memblocks(model, x, parallel, show_progress_bar): +def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None, + patch_size=1, decode=False): B, T, C, H, W = x.shape if parallel: x = x.reshape(B*T, C, H, W) + if not decode and patch_size > 1: + x = F.pixel_unshuffle(x, patch_size) # parallel over input timesteps, iterate over blocks for b in tqdm(model, disable=not show_progress_bar): if isinstance(b, MemBlock): @@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): x = b(x, mem) else: x = b(x) - BT, C, H, W = x.shape - T = BT // B - x = x.view(B, T, C, H, W) + if decode and patch_size > 1: + x = F.pixel_shuffle(x, patch_size) + x = x.view(B, x.shape[0] // B, *x.shape[1:]) + x = x.to(output_device) else: out = [] - work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) + # Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views). + # Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode). + work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)]) progress_bar = tqdm(range(T), disable=not show_progress_bar) mem = [None] * len(model) while work_queue: xt, i = work_queue.popleft() if i == 0: progress_bar.update(1) + if not decode and patch_size > 1: + xt = F.pixel_unshuffle(xt, patch_size) if i == len(model): - out.append(xt) + if decode and patch_size > 1: + xt = F.pixel_shuffle(xt, patch_size) + out.append(xt.to(output_device)) del xt else: b = model[i] @@ -165,24 +176,20 @@ class TAEHV(nn.Module): def encode(self, x, **kwargs): x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if self.patch_size > 1: - B, T, C, H, W = x.shape - x = x.reshape(B * T, C, H, W) - x = F.pixel_unshuffle(x, self.patch_size) - x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) if x.shape[1] % self.t_downscale != 0: # pad at end to multiple of t_downscale n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) - x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) + x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar, + patch_size=self.patch_size).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) - if self.patch_size > 1: - x = F.pixel_shuffle(x, self.patch_size) + x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar, + output_device=comfy.model_management.intermediate_device(), + patch_size=self.patch_size, decode=True) return x[:, self.frames_to_trim:].movedim(2, 1)