Speed up and reduce VRAM use of the QWEN VAE and the WAN VAE (less so) (#12036)

* ops: introduce autopad for conv3d

This works around PyTorch's missing ability to causally pad as part of the
conv kernel and avoids massive weight duplications for padding.
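
As a rough illustration (the class name and structure here are hypothetical, not the actual comfy.ops code), an op-level autopad can be a keyword passed at forward time; for a "causal_zero" mode the op trims the temporal taps of the weight that would only ever multiply implicit zeros, instead of padding the activation:

    import torch
    import torch.nn as nn

    class AutopadConv3d(nn.Conv3d):
        # Illustrative sketch: resolve causal zero padding against the weight
        # rather than the input tensor.
        def forward(self, x, autopad=None):
            weight = self.weight
            if autopad == "causal_zero":
                # As in the VAE fast path, this is only exercised when x has a
                # single frame, so the surviving tap covers exactly the real input.
                weight = weight[:, :, -x.shape[2]:, :, :]
            return self._conv_forward(x, weight, self.bias)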

* wan-vae: rework causal padding

The current code uses F.pad, which takes a full deep copy of the tensor and is
liable to be the VRAM peak. Instead, kick spatial padding back to the op and
consolidate the temporal padding with the cat used for the cache.
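
A minimal sketch of the memory argument, with hypothetical helper names: the old shape of the logic cats the cache and then F.pad copies the whole result again at the VRAM peak, while the reworked path materializes only the missing leading frames and performs a single cat, leaving spatial padding to the conv op:

    import torch
    import torch.nn.functional as F

    def assemble_old(x, cache_x, temporal_pad):
        # Old approach: cat with the cache, then F.pad duplicates the whole
        # tensor to add spatial zeros plus any remaining leading frames.
        if cache_x is not None:
            x = torch.cat([cache_x, x], dim=2)
            temporal_pad = max(0, temporal_pad - cache_x.shape[2])
        return F.pad(x, (1, 1, 1, 1, temporal_pad, 0))  # pads W, H, then leading T

    def assemble_new(x, cache_x, temporal_pad):
        # Reworked approach: only the missing leading frames are allocated as
        # zeros; spatial padding is handled inside the conv op itself.
        if cache_x is not None:
            temporal_pad = max(0, temporal_pad - cache_x.shape[2])
        pad_shape = list(x.shape)
        pad_shape[2] = temporal_pad
        zeros = torch.zeros(pad_shape, device=x.device, dtype=x.dtype)
        pieces = [t for t in (zeros, cache_x, x) if t is not None and t.shape[2] > 0]
        return torch.cat(pieces, dim=2)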

* wan-vae: implement zero pad fast path

The WAN VAE is also the QWEN VAE, where it is used single-image. In that case
the convolutions are zero-padded 3D convolutions over a single frame, which
means the VAE is effectively 2D using only the last element of the conv weight
in the temporal dimension. Fast-path this to avoid adding zeros that just
evaporate in the convolution math but still cost computation.
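
A small self-contained check of that claim, assuming a single-frame input and plain zero padding: a causally zero-padded 3D convolution matches a convolution that keeps only the last temporal slice of the weight.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(1, 8, 1, 16, 16)       # single frame: (N, C, T=1, H, W)
    weight = torch.randn(16, 8, 3, 3, 3)   # temporal kernel size 3
    bias = torch.randn(16)

    # Reference: prepend kT - 1 = 2 zero frames (causal zero padding).
    ref = F.conv3d(F.pad(x, (0, 0, 0, 0, 2, 0)), weight, bias, padding=(0, 1, 1))

    # Fast path: skip the zero frames, keep only the last temporal tap.
    fast = F.conv3d(x, weight[:, :, -1:, :, :], bias, padding=(0, 1, 1))

    assert torch.allclose(ref, fast, atol=1e-5)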
rattus 2026-01-23 16:56:14 -08:00 committed by GitHub
parent 9cf299a9f9
commit 4e6a1b66a9
2 changed files with 23 additions and 14 deletions


@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from comfy.ldm.modules.diffusionmodules.model import vae_attention
+from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
 import comfy.ops
 
 ops = comfy.ops.disable_weight_init
@@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._padding = (self.padding[2], self.padding[2], self.padding[1],
-                         self.padding[1], 2 * self.padding[0], 0)
-        self.padding = (0, 0, 0)
+        self._padding = 2 * self.padding[0]
+        self.padding = (0, self.padding[1], self.padding[2])
 
     def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
         if cache_list is not None:
             cache_x = cache_list[cache_idx]
             cache_list[cache_idx] = None
-        padding = list(self._padding)
-        if cache_x is not None and self._padding[4] > 0:
-            cache_x = cache_x.to(x.device)
-            x = torch.cat([cache_x, x], dim=2)
-            padding[4] -= cache_x.shape[2]
-            del cache_x
-        x = F.pad(x, padding)
+
+        if cache_x is None and x.shape[2] == 1:
+            #Fast path - the op will pad for use by truncating the weight
+            #and save math on a pile of zeros.
+            return super().forward(x, autopad="causal_zero")
+
+        if self._padding > 0:
+            padding_needed = self._padding
+            if cache_x is not None:
+                cache_x = cache_x.to(x.device)
+                padding_needed = max(0, padding_needed - cache_x.shape[2])
+            padding_shape = list(x.shape)
+            padding_shape[2] = padding_needed
+            padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
+            x = torch_cat_if_needed([padding, cache_x, x], dim=2)
+            del cache_x
         return super().forward(x)


@@ -203,7 +203,9 @@ class disable_weight_init:
         def reset_parameters(self):
             return None
 
-        def _conv_forward(self, input, weight, bias, *args, **kwargs):
+        def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
+            if autopad == "causal_zero":
+                weight = weight[:, :, -input.shape[2]:, :, :]
             if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
                 out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
                 if bias is not None:
@@ -212,15 +214,15 @@ class disable_weight_init:
             else:
                 return super()._conv_forward(input, weight, bias, *args, **kwargs)
 
-        def forward_comfy_cast_weights(self, input):
+        def forward_comfy_cast_weights(self, input, autopad=None):
             weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-            x = self._conv_forward(input, weight, bias)
+            x = self._conv_forward(input, weight, bias, autopad=autopad)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
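
Note on the dispatch: the stock torch.nn.Conv3d.forward only accepts the input tensor, so an explicit autopad keyword has to be routed through forward_comfy_cast_weights; that is presumably why forward now also checks for "autopad" in kwargs before falling back to super().forward.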