mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 22:00:19 +08:00
speed up and reduce VRAM of QWEN VAE and WAN (less so) (#12036)
Some checks are pending
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Some checks are pending
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
* ops: introduce autopad for conv3d This works around pytorch missing ability to causal pad as part of the kernel and avoids massive weight duplications for padding. * wan-vae: rework causal padding This currently uses F.pad which takes a full deep copy and is liable to be the VRAM peak. Instead, kick spatial padding back to the op and consolidate the temporal padding with the cat for the cache. * wan-vae: implement zero pad fast path The WAN VAE is also QWEN where it is used single-image. These convolutions are however zero padded 3d convolutions, which means the VAE is actually just 2D down the last element of the conv weight in the temporal dimension. Fast path this, to avoid adding zeros that then just evaporate in convoluton math but cost computation.
This commit is contained in:
parent
9cf299a9f9
commit
4e6a1b66a9
@ -5,7 +5,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
self._padding = 2 * self.padding[0]
|
||||||
self.padding[1], 2 * self.padding[0], 0)
|
self.padding = (0, self.padding[1], self.padding[2])
|
||||||
self.padding = (0, 0, 0)
|
|
||||||
|
|
||||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||||
if cache_list is not None:
|
if cache_list is not None:
|
||||||
cache_x = cache_list[cache_idx]
|
cache_x = cache_list[cache_idx]
|
||||||
cache_list[cache_idx] = None
|
cache_list[cache_idx] = None
|
||||||
|
|
||||||
padding = list(self._padding)
|
if cache_x is None and x.shape[2] == 1:
|
||||||
if cache_x is not None and self._padding[4] > 0:
|
#Fast path - the op will pad for use by truncating the weight
|
||||||
cache_x = cache_x.to(x.device)
|
#and save math on a pile of zeros.
|
||||||
x = torch.cat([cache_x, x], dim=2)
|
return super().forward(x, autopad="causal_zero")
|
||||||
padding[4] -= cache_x.shape[2]
|
|
||||||
|
if self._padding > 0:
|
||||||
|
padding_needed = self._padding
|
||||||
|
if cache_x is not None:
|
||||||
|
cache_x = cache_x.to(x.device)
|
||||||
|
padding_needed = max(0, padding_needed - cache_x.shape[2])
|
||||||
|
padding_shape = list(x.shape)
|
||||||
|
padding_shape[2] = padding_needed
|
||||||
|
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
|
||||||
|
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
|
||||||
del cache_x
|
del cache_x
|
||||||
x = F.pad(x, padding)
|
|
||||||
|
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|
||||||
|
|||||||
10
comfy/ops.py
10
comfy/ops.py
@ -203,7 +203,9 @@ class disable_weight_init:
|
|||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
|
||||||
|
if autopad == "causal_zero":
|
||||||
|
weight = weight[:, :, -input.shape[2]:, :, :]
|
||||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
@ -212,15 +214,15 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input, autopad=None):
|
||||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
x = self._conv_forward(input, weight, bias)
|
x = self._conv_forward(input, weight, bias, autopad=autopad)
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
run_every_op()
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user