wan: vae: restyle a little to match LTX

This commit is contained in:
Rattus 2026-03-17 16:21:16 +10:00
parent 8658c570f9
commit 0f5621ee32

View File

@ -380,28 +380,25 @@ class Decoder3d(nn.Module):
out_chunks = [] out_chunks = []
def run_head(x, feat_idx): def run_up(layer_idx, x_ref, feat_idx):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
def run_upsamples(layer_idx, x_ref, feat_idx):
x = x_ref[0] x = x_ref[0]
x_ref[0] = None x_ref[0] = None
if layer_idx >= len(self.upsamples): if layer_idx >= len(self.upsamples):
run_head(x, feat_idx) for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return return
layer = self.upsamples[layer_idx] layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]): for frame_idx in range(x.shape[2]):
run_upsamples( run_up(
layer_idx, layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]], [x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_idx.copy(), feat_idx.copy(),
@ -416,9 +413,9 @@ class Decoder3d(nn.Module):
next_x_ref = [x] next_x_ref = [x]
del x del x
run_upsamples(layer_idx + 1, next_x_ref, feat_idx) run_up(layer_idx + 1, next_x_ref, feat_idx)
run_upsamples(0, [x], feat_idx) run_up(0, [x], feat_idx)
return out_chunks return out_chunks