mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
wan: vae: recurse and chunk for 2+2 frames on decode
Avoid having to clone off slices of 4 frame chunks and reduce the size of the big 6 frame convolutions down to 4. Save the VRAMs.
This commit is contained in:
parent
d8fa68084f
commit
0ebd807eb9
@ -378,24 +378,43 @@ class Decoder3d(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## upsamples
|
out_chunks = []
|
||||||
for layer in self.upsamples:
|
|
||||||
|
def run_head(x, feat_idx):
|
||||||
|
for layer in self.head:
|
||||||
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
|
x = layer(x, feat_cache[feat_idx[0]])
|
||||||
|
feat_cache[feat_idx[0]] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
out_chunks.append(x)
|
||||||
|
|
||||||
|
def run_upsamples(layer_idx, x, feat_idx):
|
||||||
|
if layer_idx >= len(self.upsamples):
|
||||||
|
run_head(x, feat_idx)
|
||||||
|
return
|
||||||
|
|
||||||
|
layer = self.upsamples[layer_idx]
|
||||||
|
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
||||||
|
for frame_idx in range(x.shape[2]):
|
||||||
|
run_upsamples(
|
||||||
|
layer_idx,
|
||||||
|
x[:, :, frame_idx:frame_idx + 1, :, :],
|
||||||
|
feat_idx.copy(),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## head
|
run_upsamples(layer_idx + 1, x, feat_idx)
|
||||||
for layer in self.head:
|
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
run_upsamples(0, x, feat_idx)
|
||||||
idx = feat_idx[0]
|
return out_chunks
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
x = layer(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def count_cache_layers(model):
|
def count_cache_layers(model):
|
||||||
@ -483,5 +502,5 @@ class WanVAE(nn.Module):
|
|||||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||||
feat_cache=feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx)
|
||||||
out = torch.cat([out, out_], 2)
|
out += out_
|
||||||
return out
|
return torch.cat(out, 2)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user