Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-03 13:52:31 +08:00)
Reduce video tiny VAE peak VRAM and decode time (CORE-127) (#13617)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* Update taehv.py
* Simplify
* Simplify pixel_unshuffle dispatch
This commit is contained in:
Parent: fce0398470
Commit: 0e25a6936e
@ -7,6 +7,7 @@ from tqdm.auto import tqdm
|
|||||||
from collections import namedtuple, deque
|
from collections import namedtuple, deque
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.model_management
|
||||||
operations=comfy.ops.disable_weight_init
|
operations=comfy.ops.disable_weight_init
|
||||||
|
|
||||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||||
@ -47,11 +48,14 @@ class TGrow(nn.Module):
|
|||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x.reshape(-1, C, H, W)
|
return x.reshape(-1, C, H, W)
|
||||||
|
|
||||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None,
|
||||||
|
patch_size=1, decode=False):
|
||||||
|
|
||||||
B, T, C, H, W = x.shape
|
B, T, C, H, W = x.shape
|
||||||
if parallel:
|
if parallel:
|
||||||
x = x.reshape(B*T, C, H, W)
|
x = x.reshape(B*T, C, H, W)
|
||||||
|
if not decode and patch_size > 1:
|
||||||
|
x = F.pixel_unshuffle(x, patch_size)
|
||||||
# parallel over input timesteps, iterate over blocks
|
# parallel over input timesteps, iterate over blocks
|
||||||
for b in tqdm(model, disable=not show_progress_bar):
|
for b in tqdm(model, disable=not show_progress_bar):
|
||||||
if isinstance(b, MemBlock):
|
if isinstance(b, MemBlock):
|
||||||
@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
|||||||
x = b(x, mem)
|
x = b(x, mem)
|
||||||
else:
|
else:
|
||||||
x = b(x)
|
x = b(x)
|
||||||
BT, C, H, W = x.shape
|
if decode and patch_size > 1:
|
||||||
T = BT // B
|
x = F.pixel_shuffle(x, patch_size)
|
||||||
x = x.view(B, T, C, H, W)
|
x = x.view(B, x.shape[0] // B, *x.shape[1:])
|
||||||
|
x = x.to(output_device)
|
||||||
else:
|
else:
|
||||||
out = []
|
out = []
|
||||||
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
|
# Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
|
||||||
|
# Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
|
||||||
|
work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)])
|
||||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
||||||
mem = [None] * len(model)
|
mem = [None] * len(model)
|
||||||
while work_queue:
|
while work_queue:
|
||||||
xt, i = work_queue.popleft()
|
xt, i = work_queue.popleft()
|
||||||
if i == 0:
|
if i == 0:
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
|
if not decode and patch_size > 1:
|
||||||
|
xt = F.pixel_unshuffle(xt, patch_size)
|
||||||
if i == len(model):
|
if i == len(model):
|
||||||
out.append(xt)
|
if decode and patch_size > 1:
|
||||||
|
xt = F.pixel_shuffle(xt, patch_size)
|
||||||
|
out.append(xt.to(output_device))
|
||||||
del xt
|
del xt
|
||||||
else:
|
else:
|
||||||
b = model[i]
|
b = model[i]
|
||||||
@ -165,24 +176,20 @@ class TAEHV(nn.Module):
|
|||||||
|
|
||||||
def encode(self, x, **kwargs):
|
def encode(self, x, **kwargs):
|
||||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||||
if self.patch_size > 1:
|
|
||||||
B, T, C, H, W = x.shape
|
|
||||||
x = x.reshape(B * T, C, H, W)
|
|
||||||
x = F.pixel_unshuffle(x, self.patch_size)
|
|
||||||
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
|
|
||||||
if x.shape[1] % self.t_downscale != 0:
|
if x.shape[1] % self.t_downscale != 0:
|
||||||
# pad at end to multiple of t_downscale
|
# pad at end to multiple of t_downscale
|
||||||
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
|
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
|
||||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||||
x = torch.cat([x, padding], 1)
|
x = torch.cat([x, padding], 1)
|
||||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar,
|
||||||
|
patch_size=self.patch_size).movedim(2, 1)
|
||||||
return self.process_out(x)
|
return self.process_out(x)
|
||||||
|
|
||||||
def decode(self, x, **kwargs):
|
def decode(self, x, **kwargs):
|
||||||
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
|
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
|
||||||
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
|
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
|
||||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar,
|
||||||
if self.patch_size > 1:
|
output_device=comfy.model_management.intermediate_device(),
|
||||||
x = F.pixel_shuffle(x, self.patch_size)
|
patch_size=self.patch_size, decode=True)
|
||||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user