mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 08:12:34 +08:00
wan: vae: Don't recursion in local fns (move run_up)
Moved Decoder3d’s recursive run_up out of forward into a class method to avoid nested closure self-reference cycles. This avoids cyclic garbage that delays garbage of tensors which in turn delays VRAM release before tiled fallback.
This commit is contained in:
parent
335532c25f
commit
7a88c578e0
@ -360,6 +360,43 @@ class Decoder3d(nn.Module):
|
|||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||||
|
|
||||||
|
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
|
||||||
|
x = x_ref[0]
|
||||||
|
x_ref[0] = None
|
||||||
|
if layer_idx >= len(self.upsamples):
|
||||||
|
for layer in self.head:
|
||||||
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
|
x = layer(x, feat_cache[feat_idx[0]])
|
||||||
|
feat_cache[feat_idx[0]] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
out_chunks.append(x)
|
||||||
|
return
|
||||||
|
|
||||||
|
layer = self.upsamples[layer_idx]
|
||||||
|
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
||||||
|
for frame_idx in range(x.shape[2]):
|
||||||
|
self.run_up(
|
||||||
|
layer_idx,
|
||||||
|
[x[:, :, frame_idx:frame_idx + 1, :, :]],
|
||||||
|
feat_cache,
|
||||||
|
feat_idx.copy(),
|
||||||
|
out_chunks,
|
||||||
|
)
|
||||||
|
del x
|
||||||
|
return
|
||||||
|
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = layer(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
next_x_ref = [x]
|
||||||
|
del x
|
||||||
|
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
## conv1
|
## conv1
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
@ -380,42 +417,7 @@ class Decoder3d(nn.Module):
|
|||||||
|
|
||||||
out_chunks = []
|
out_chunks = []
|
||||||
|
|
||||||
def run_up(layer_idx, x_ref, feat_idx):
|
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
|
||||||
x = x_ref[0]
|
|
||||||
x_ref[0] = None
|
|
||||||
if layer_idx >= len(self.upsamples):
|
|
||||||
for layer in self.head:
|
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
|
||||||
x = layer(x, feat_cache[feat_idx[0]])
|
|
||||||
feat_cache[feat_idx[0]] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
out_chunks.append(x)
|
|
||||||
return
|
|
||||||
|
|
||||||
layer = self.upsamples[layer_idx]
|
|
||||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
|
||||||
for frame_idx in range(x.shape[2]):
|
|
||||||
run_up(
|
|
||||||
layer_idx,
|
|
||||||
[x[:, :, frame_idx:frame_idx + 1, :, :]],
|
|
||||||
feat_idx.copy(),
|
|
||||||
)
|
|
||||||
del x
|
|
||||||
return
|
|
||||||
|
|
||||||
if feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
next_x_ref = [x]
|
|
||||||
del x
|
|
||||||
run_up(layer_idx + 1, next_x_ref, feat_idx)
|
|
||||||
|
|
||||||
run_up(0, [x], feat_idx)
|
|
||||||
return out_chunks
|
return out_chunks
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user