ComfyUI/comfy/ldm/flux
rattus128 e42682b24e
Reduce Peak WAN inference VRAM usage (#9898)
* flux: Do the xq and xk ropes one at a time

This was doing independent, interleaved tensor math on the q and k
tensors, holding more than the minimum number of intermediates in
VRAM. On a bad day, it would OOM on the xk intermediates.

Do all of the q work and then all of the k work, so torch can garbage
collect q's intermediates before k allocates its own.

This reduces peak VRAM usage for some WAN2.2 inferences (at least).
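
A minimal sketch of the idea, assuming a flux-style rope that pairs up the
last dimension and rotates it by precomputed frequencies; the shapes and
helper names here are illustrative, not the shipped ComfyUI code:

```python
import torch

def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # freqs_cis is assumed broadcastable to (..., seq, head_dim // 2, 2, 2).
    # Finish all of the q math first, so its float/reshape temporaries are
    # no longer referenced when the k math starts allocating its own.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xq_out = xq_out.reshape(*xq.shape).type_as(xq)
    del xq_  # drop the last reference so the caching allocator can reuse it

    # Only now build the k intermediates; the peak holds one set, not two.
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    xk_out = xk_out.reshape(*xk.shape).type_as(xk)
    del xk_
    return xq_out, xk_out
```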

* wan: Optimize qkv intermediates on attention

As commented. The former logic computed independent pieces of QKV in
parallel, which held more inference intermediates in VRAM and spiked
peak usage. Fully roping Q and garbage collecting its intermediates
before touching K reduces the peak inference VRAM usage.
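
The same ordering applied at the attention level, again as a hedged sketch:
the q/k/v projections, per-head norms, and rope_single helper below are
assumptions made for illustration, not the actual WAN attention module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def rope_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper that ropes one tensor at a time, so q and k can
    # be processed sequentially instead of together.
    x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
    out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return out.reshape(*x.shape).type_as(x)

class SelfAttentionSketch(nn.Module):
    # Illustrative module only: linear q/k/v projections with per-head norms.
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.norm_q = nn.LayerNorm(self.head_dim)
        self.norm_k = nn.LayerNorm(self.head_dim)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # freqs_cis is assumed broadcastable to (seq, 1, head_dim // 2, 2, 2).
        b, s, _ = x.shape
        shape = (b, s, self.num_heads, self.head_dim)
        # Fully project, normalize and rope q before k is computed, so q's
        # temporaries are already dead (and reusable by the allocator) when
        # k starts allocating its own intermediates.
        q = rope_single(self.norm_q(self.q(x).view(shape)), freqs_cis)
        k = rope_single(self.norm_k(self.k(x).view(shape)), freqs_cis)
        v = self.v(x).view(shape)
        # (b, s, h, d) -> (b, h, s, d) for PyTorch's built-in SDPA.
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        return out.transpose(1, 2).reshape(b, s, -1)
```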
2025-09-16 19:21:14 -04:00
controlnet.py Make flux controlnet work with sd3 text enc. (#8599) 2025-06-19 18:50:05 -04:00
layers.py Enable Runtime Selection of Attention Functions (#9639) 2025-09-12 18:07:38 -04:00
math.py Reduce Peak WAN inference VRAM usage (#9898) 2025-09-16 19:21:14 -04:00
model.py Enable Runtime Selection of Attention Functions (#9639) 2025-09-12 18:07:38 -04:00
redux.py Support new flux model variants. 2024-11-21 08:38:23 -05:00