wan-vae: Switch off feature cache for single frame (#12090)

The code throughout is None safe to just skip the feature cache saving
step if none. Set it none in single frame use so qwen doesn't burn VRAM
on the unused cache.
This commit is contained in:
rattus 2026-01-26 16:40:19 -08:00 committed by GitHub
parent ad53e78f11
commit 6516ab335d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -479,10 +479,12 @@ class WanVAE(nn.Module):
def encode(self, x):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]
@ -502,10 +504,11 @@ class WanVAE(nn.Module):
def decode(self, z):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]