wan: vae: Fix light/color change

There was an issue where the resample split was too early and dropped one
of the rolling convolutions a frame early. This is most noticable as a
lighting/color change between pixel frames 5->6 (latent 2->3), or as a
lighting change between the first and last frame in an FLF wan flow.
This commit is contained in:
Rattus 2026-03-22 07:34:17 +10:00
parent 11c15d8832
commit 182e0bae11

View File

@ -376,11 +376,16 @@ class Decoder3d(nn.Module):
return return
layer = self.upsamples[layer_idx] layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: if feat_cache is not None:
for frame_idx in range(x.shape[2]): x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
for frame_idx in range(0, x.shape[2], 2):
self.run_up( self.run_up(
layer_idx, layer_idx + 1,
[x[:, :, frame_idx:frame_idx + 1, :, :]], [x[:, :, frame_idx:frame_idx + 2, :, :]],
feat_cache, feat_cache,
feat_idx.copy(), feat_idx.copy(),
out_chunks, out_chunks,
@ -388,11 +393,6 @@ class Decoder3d(nn.Module):
del x del x
return return
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
next_x_ref = [x] next_x_ref = [x]
del x del x
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks) self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)