mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-25 02:53:29 +08:00
Fix VRAM leak in tiler fallback in video VAEs (#13073)
* sd: soft_empty_cache on tiler fallback This doesn't cost a lot and creates the expected VRAM reduction in resource monitors when you fall back to the tiler. * wan: vae: Don't recurse in local fns (move run_up) Moved Decoder3d’s recursive run_up out of forward into a class method to avoid nested closure self-reference cycles. This avoids cyclic garbage that delays garbage collection of tensors, which in turn delays VRAM release before tiled fallback. * ltx: vae: Don't recurse in local fns (move run_up) Move the recursive run_up out of forward into a class method to avoid nested closure self-reference cycles. This avoids cyclic garbage that delays garbage collection of tensors, which in turn delays VRAM release before tiled fallback.
This commit is contained in:
parent
8458ae2686
commit
82b868a45a
@ -536,6 +536,53 @@ class Decoder(nn.Module):
|
|||||||
c, (ts, hs, ws), to = self._output_scale
|
c, (ts, hs, ws), to = self._output_scale
|
||||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||||
|
|
||||||
|
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
    """Recursively push ``sample_ref[0]`` through ``up_blocks[idx:]`` and the
    output head, copying decoded frames into ``output_buffer`` starting at
    ``output_offset[0]`` (a one-element list used as a mutable cursor).

    The activation travels inside a one-element list whose slot is cleared on
    entry, so no parent frame keeps a reference to the tensor alive; this lets
    VRAM be released as early as possible while the recursion unwinds.
    """
    sample = sample_ref[0]
    sample_ref[0] = None  # drop the caller-held reference immediately

    if idx >= len(self.up_blocks):
        # Past the last up block: run the output head and write the result out.
        sample = self.conv_norm_out(sample)
        if timestep_shift_scale is not None:
            shift, scale = timestep_shift_scale
            sample = sample * (1 + scale) + shift
        sample = self.conv_act(sample)
        if ended:
            mark_conv3d_ended(self.conv_out)
        sample = self.conv_out(sample, causal=self.causal)
        if sample is not None and sample.shape[2] > 0:
            sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
            t = sample.shape[2]
            output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
            output_offset[0] += t
        return

    up_block = self.up_blocks[idx]
    if ended:
        mark_conv3d_ended(up_block)
    if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
        sample = checkpoint_fn(up_block)(
            sample, causal=self.causal, timestep=scaled_timestep
        )
    else:
        sample = checkpoint_fn(up_block)(sample, causal=self.causal)

    if sample is None or sample.shape[2] == 0:
        return

    total_bytes = sample.numel() * sample.element_size()
    num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size

    if num_chunks == 1:
        # Not chunking: hand the only live reference to the callee (and delete
        # our local) so the tensor can be freed as soon as the callee is done.
        next_ref = [sample]
        del sample
        self.run_up(idx + 1, next_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
        return

    # Too big for one call: split along the temporal axis and recurse per
    # chunk; only the last chunk carries the `ended` flag forward.
    pieces = torch.chunk(sample, chunks=num_chunks, dim=2)
    for chunk_idx, piece in enumerate(pieces):
        self.run_up(idx + 1, [piece], ended and chunk_idx == len(pieces) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
sample: torch.FloatTensor,
|
sample: torch.FloatTensor,
|
||||||
@ -591,54 +638,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
max_chunk_size = get_max_chunk_size(sample.device)
|
max_chunk_size = get_max_chunk_size(sample.device)
|
||||||
|
|
||||||
def run_up(idx, sample_ref, ended):
|
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||||
sample = sample_ref[0]
|
|
||||||
sample_ref[0] = None
|
|
||||||
if idx >= len(self.up_blocks):
|
|
||||||
sample = self.conv_norm_out(sample)
|
|
||||||
if timestep_shift_scale is not None:
|
|
||||||
shift, scale = timestep_shift_scale
|
|
||||||
sample = sample * (1 + scale) + shift
|
|
||||||
sample = self.conv_act(sample)
|
|
||||||
if ended:
|
|
||||||
mark_conv3d_ended(self.conv_out)
|
|
||||||
sample = self.conv_out(sample, causal=self.causal)
|
|
||||||
if sample is not None and sample.shape[2] > 0:
|
|
||||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
|
||||||
t = sample.shape[2]
|
|
||||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
|
||||||
output_offset[0] += t
|
|
||||||
return
|
|
||||||
|
|
||||||
up_block = self.up_blocks[idx]
|
|
||||||
if (ended):
|
|
||||||
mark_conv3d_ended(up_block)
|
|
||||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
|
||||||
sample = checkpoint_fn(up_block)(
|
|
||||||
sample, causal=self.causal, timestep=scaled_timestep
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
|
||||||
|
|
||||||
if sample is None or sample.shape[2] == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
total_bytes = sample.numel() * sample.element_size()
|
|
||||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
|
||||||
|
|
||||||
if num_chunks == 1:
|
|
||||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
|
||||||
next_sample_ref = [sample]
|
|
||||||
del sample
|
|
||||||
run_up(idx + 1, next_sample_ref, ended)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
|
||||||
|
|
||||||
for chunk_idx, sample1 in enumerate(samples):
|
|
||||||
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
|
||||||
|
|
||||||
run_up(0, [sample], True)
|
|
||||||
|
|
||||||
return output_buffer
|
return output_buffer
|
||||||
|
|
||||||
|
|||||||
@ -360,6 +360,43 @@ class Decoder3d(nn.Module):
|
|||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||||
|
|
||||||
|
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
    """Recursively run ``x_ref[0]`` through ``upsamples[layer_idx:]`` and then
    the output head, appending each decoded piece to ``out_chunks``.

    The tensor is carried in a one-element list whose slot is cleared on
    entry, so no caller frame pins it and VRAM can be released as soon as
    each layer has consumed its input.
    """
    x = x_ref[0]
    x_ref[0] = None  # release the caller-held reference right away

    if layer_idx >= len(self.upsamples):
        # Output head: causal convs read the rolling feature cache and then
        # refresh it with the trailing CACHE_T frames of their input.
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                cache_x = x[:, :, -CACHE_T:, :, :]
                x = layer(x, feat_cache[feat_idx[0]])
                feat_cache[feat_idx[0]] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        out_chunks.append(x)
        return

    layer = self.upsamples[layer_idx]
    if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
        # Temporal upsampling is processed one frame at a time; each frame
        # recursion gets its own copy of the cache index so siblings do not
        # interfere with each other's cache positions.
        for frame_idx in range(x.shape[2]):
            self.run_up(
                layer_idx,
                [x[:, :, frame_idx:frame_idx + 1, :, :]],
                feat_cache,
                feat_idx.copy(),
                out_chunks,
            )
        del x
        return

    if feat_cache is not None:
        x = layer(x, feat_cache, feat_idx)
    else:
        x = layer(x)

    # Hand the sole remaining reference down (deleting our local first) so
    # the callee can free the tensor as soon as it is done with it.
    next_ref = [x]
    del x
    self.run_up(layer_idx + 1, next_ref, feat_cache, feat_idx, out_chunks)
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
## conv1
|
## conv1
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
@ -380,42 +417,7 @@ class Decoder3d(nn.Module):
|
|||||||
|
|
||||||
out_chunks = []
|
out_chunks = []
|
||||||
|
|
||||||
def run_up(layer_idx, x_ref, feat_idx):
|
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
|
||||||
x = x_ref[0]
|
|
||||||
x_ref[0] = None
|
|
||||||
if layer_idx >= len(self.upsamples):
|
|
||||||
for layer in self.head:
|
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
|
||||||
x = layer(x, feat_cache[feat_idx[0]])
|
|
||||||
feat_cache[feat_idx[0]] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
out_chunks.append(x)
|
|
||||||
return
|
|
||||||
|
|
||||||
layer = self.upsamples[layer_idx]
|
|
||||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
|
||||||
for frame_idx in range(x.shape[2]):
|
|
||||||
run_up(
|
|
||||||
layer_idx,
|
|
||||||
[x[:, :, frame_idx:frame_idx + 1, :, :]],
|
|
||||||
feat_idx.copy(),
|
|
||||||
)
|
|
||||||
del x
|
|
||||||
return
|
|
||||||
|
|
||||||
if feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
next_x_ref = [x]
|
|
||||||
del x
|
|
||||||
run_up(layer_idx + 1, next_x_ref, feat_idx)
|
|
||||||
|
|
||||||
run_up(0, [x], feat_idx)
|
|
||||||
return out_chunks
|
return out_chunks
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -978,6 +978,7 @@ class VAE:
|
|||||||
do_tile = True
|
do_tile = True
|
||||||
|
|
||||||
if do_tile:
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
dims = samples_in.ndim - 2
|
dims = samples_in.ndim - 2
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||||
@ -1059,6 +1060,7 @@ class VAE:
|
|||||||
do_tile = True
|
do_tile = True
|
||||||
|
|
||||||
if do_tile:
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
if self.latent_dim == 3:
|
if self.latent_dim == 3:
|
||||||
tile = 256
|
tile = 256
|
||||||
overlap = tile // 4
|
overlap = tile // 4
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user