wan: encode frames 2x2.

Reduce VRAM usage greatly by encoding frames 2 at a time rather than
4.
This commit is contained in:
Rattus 2026-03-16 21:57:26 +10:00
parent 0ebd807eb9
commit 6c4843567d

View File

@ -458,11 +458,12 @@ class WanVAE(nn.Module):
conv_idx = [0]
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
t = 1 + ((t - 1) // 4) * 4
iter_ = 1 + (t - 1) // 2
feat_map = None
if iter_ > 1:
feat_map = [None] * count_cache_layers(self.encoder)
## 对encode输入的x按时间拆分为1、4、4、4....
## 对encode输入的x按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
for i in range(iter_):
conv_idx = [0]
if i == 0:
@ -472,7 +473,7 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx,
final=(i == (iter_ - 1)))