Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-05-09 16:52:32 +08:00
ltx: vae: implement chunked encoder + CPU IO chunking
People are working with large frame counts in LTX, including V2V flows. Implement a time-chunked encoder to keep VRAM usage down, using the converse of the new CPU pre-allocation technique: chunks are brought over from the CPU just in time.
This commit is contained in:
parent d0db3bb104
commit 60a3de0ef3
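The idea in outline: keep the full pixel tensor on the CPU, split it along the time axis, and copy each chunk to the GPU just before it is encoded, so peak VRAM scales with the chunk size rather than the total frame count. A minimal sketch of the pattern, not the commit's actual code (the real encoder additionally carries causal-conv state across chunk boundaries, which this ignores; `chunked_encode` and `chunk_frames` are illustrative names):

    import torch

    def chunked_encode(encoder, video: torch.Tensor, device: torch.device,
                       chunk_frames: int = 16) -> torch.Tensor:
        # video: [B, C, T, H, W], kept on the CPU; each chunk is moved to the
        # GPU just in time, encoded, and the results concatenated over time.
        outputs = []
        for t in range(0, video.shape[2], chunk_frames):
            chunk = video[:, :, t:t + chunk_frames].to(device)  # JIT host->device copy
            outputs.append(encoder(chunk))
            del chunk  # release the on-device chunk before the next transfer
        return torch.cat(outputs, dim=2)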
@@ -233,10 +233,7 @@ class Encoder(nn.Module):
         self.gradient_checkpointing = False
 
-    def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        r"""The forward method of the `Encoder` class."""
-
-        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
+    def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
         sample = self.conv_in(sample)
 
         checkpoint_fn = (
@@ -247,10 +244,14 @@ class Encoder(nn.Module):
 
         for down_block in self.down_blocks:
             sample = checkpoint_fn(down_block)(sample)
+        if sample is None or sample.shape[2] == 0:
+            return None
 
         sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
+        if sample is None or sample.shape[2] == 0:
+            return None
 
         if self.latent_log_var == "uniform":
             last_channel = sample[:, -1:, ...]
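These guards cover the case where a chunk contributes nothing: with the causal-conv cache holding frames back at chunk boundaries, a short chunk can leave the down blocks (or conv_out) with zero frames on the time axis, and _forward_chunk signals that by returning None. The caller then simply skips it, which is presumably what torch_cat_if_needed does as well; a hypothetical stand-in for that helper:

    from typing import List, Optional
    import torch

    def cat_skipping_empty(chunks: List[Optional[torch.Tensor]], dim: int = 2) -> torch.Tensor:
        # Drop None / zero-length chunk outputs, concatenate the rest over time.
        kept = [c for c in chunks if c is not None and c.shape[dim] > 0]
        return kept[0] if len(kept) == 1 else torch.cat(kept, dim=dim)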
@@ -282,9 +283,29 @@ class Encoder(nn.Module):
 
         return sample
 
+    def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
+        r"""The forward method of the `Encoder` class."""
+
+        max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2  # encoder is more memory-efficient than decoder
+        frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
+        frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
+
+        outputs = []
+        samples = [sample[:, :, :1, :, :]]
+        if sample.shape[2] > 1:
+            n = max(1, max_chunk_size // (2 * frame_size))
+            samples += list(torch.split(sample[:, :, 1:, :, :], 2 * n, dim=2))
+        for chunk_idx, chunk in enumerate(samples):
+            if chunk_idx == len(samples) - 1:
+                mark_conv3d_ended(self)
+            chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
+            output = self._forward_chunk(chunk)
+            if output is not None:
+                outputs.append(output)
+
+        return torch_cat_if_needed(outputs, dim=2)
+
     def forward(self, *args, **kwargs):
-        #No encoder support so just flag the end so it doesnt use the cache.
-        mark_conv3d_ended(self)
         try:
             return self.forward_orig(*args, **kwargs)
         finally:
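Two details worth calling out. The chunk size is derived from a per-frame byte estimate: one input frame's size, scaled by conv_in's channel expansion to approximate the post-conv footprint, divided into the device budget (doubled, since per the comment the encoder is more memory-efficient than the decoder). And the leading frame is split off on its own, matching LTX's 1 + 8·k temporal layout, while mark_conv3d_ended on the final chunk tells the causal-conv cache that no more frames are coming. Illustrative numbers only (the 3 -> 128 conv_in channel count and the 512 MB base budget are assumptions, not values from this diff):

    B, C, H, W = 1, 3, 768, 1280
    element_size = 2                                  # fp16
    frame_size = B * C * H * W * element_size         # ~5.9 MB per input frame
    frame_size = int(frame_size * (128 / 3))          # assume conv_in: 3 -> 128 channels, ~240 MB
    max_chunk_size = 512 * 1024**2 * 2                # assume a 512 MB budget, doubled for the encoder
    n = max(1, max_chunk_size // (2 * frame_size))    # n = 2, so chunks of 2*n = 4 frames after the first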
@@ -1266,9 +1287,9 @@ class VideoVAE(nn.Module):
         }
         return config
 
-    def encode(self, x):
+    def encode(self, x, device=None):
         x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
-        means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
+        means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
         return self.per_channel_statistics.normalize(means)
 
     def decode_output_shape(self, input_shape):
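The slice on the first line of encode clips the input to a valid LTX frame count: the VAE maps 1 + 8·k pixel frames to 1 + k latent frames, so anything past the last full group of 8 is dropped. The same arithmetic, worked through:

    def clip_frames(t: int) -> int:
        # Same bound as the slice above: largest 1 + 8*k not exceeding t.
        return max(1, 1 + ((t - 1) // 8) * 8)

    assert clip_frames(121) == 121  # already 1 + 15*8
    assert clip_frames(100) == 97   # trimmed to 1 + 12*8
    assert clip_frames(1) == 1      # single-frame (image) input passes through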
@@ -1038,8 +1038,13 @@ class VAE:
             batch_number = max(1, batch_number)
             samples = None
             for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
-                out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
+                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                    out = self.first_stage_model.encode(pixels_in, device=self.device)
+                else:
+                    pixels_in = pixels_in.to(self.device)
+                    out = self.first_stage_model.encode(pixels_in)
+                out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
                 if samples is None:
                     samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
                 samples[x:x + batch_number] = out
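The dispatch is opt-in: a first-stage model that sets comfy_has_chunked_io receives its input still on the CPU along with the target device, and pages data in itself; every other model keeps the old behavior of one up-front .to(self.device). A sketch of what opting in might look like (the attribute name and encode signature are from the diff; the toy class and its body are illustrative):

    import torch

    class ChunkedIOModel(torch.nn.Module):
        comfy_has_chunked_io = True  # opt in: VAE.encode passes a CPU tensor plus a device

        def __init__(self):
            super().__init__()
            self.inner = torch.nn.Conv3d(3, 8, kernel_size=1)

        def encode(self, x, device=None):
            # x arrives on the CPU; this toy version moves it in one piece,
            # where a real model would page temporal chunks in just in time,
            # as the LTX encoder above does.
            return self.inner(x.to(device))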