ltx: vae: Move decoder output buffer allocation back to forward.

So None can mean None: with the allocation moved out of forward_orig, an output_buffer of None reaching it now genuinely means no buffer was provided (which the start/resume path needs to detect), and the plain non-resume decode path gets its buffer from forward instead.
Rattus 2026-04-08 09:00:50 +10:00
parent 9dede56fac
commit ad91467ef6

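A minimal sketch of the convention this commit establishes may help before the diff itself. The DecoderSketch class, its shapes, and the placeholder decode are invented for illustration; only the split of allocation between forward and forward_orig mirrors the change below.

    import torch
    from typing import Optional

    class DecoderSketch:
        # Hypothetical stand-in for the real Decoder; assumes the output
        # shape equals the input shape to keep the sketch runnable.
        def decode_output_shape(self, latent_shape):
            return latent_shape

        def forward(self, sample: torch.Tensor,
                    output_buffer: Optional[torch.Tensor] = None) -> torch.Tensor:
            # The plain (non start-resume) entry point now owns allocation,
            # like the forward() added in the diff below.
            if output_buffer is None:
                output_buffer = torch.empty(self.decode_output_shape(sample.shape),
                                            dtype=sample.dtype)
            return self.forward_orig(sample, output_buffer=output_buffer)

        def forward_orig(self, sample, output_buffer: Optional[torch.Tensor] = None):
            # Down here a None buffer now genuinely means "none was provided";
            # the start/resume protocol will rely on that distinction.
            assert output_buffer is not None, "start/resume not implemented yet"
            output_buffer.copy_(sample)  # placeholder for the real chunked decode
            return output_buffer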

@@ -640,13 +640,6 @@ class Decoder(nn.Module):
         )
         timestep_shift_scale = ada_values.unbind(dim=1)
-        #The meaning of output_buffer == None changes.
-        #forward now needs to do this for the non start-resume case.
-        if output_buffer is None:
-            output_buffer = torch.empty(
-                self.decode_output_shape(sample.shape),
-                dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
-            )
         output_offset = [0]
         max_chunk_size = get_max_chunk_size(sample.device)
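With the allocation gone from forward_orig, the chunked output loop that follows (output_offset and max_chunk_size in the context lines above) can treat the buffer as caller-owned. A rough, one-dimensional sketch of that write pattern; everything beyond the running-offset idea is hypothetical:

    import torch

    def write_chunks(chunks, output_buffer: torch.Tensor) -> torch.Tensor:
        # Simplified analogue of the decode loop's output path: each chunk
        # is copied into the caller-owned buffer at a running offset, so
        # nothing is allocated inside the loop itself.
        offset = 0
        for chunk in chunks:
            n = chunk.shape[0]
            output_buffer[offset:offset + n].copy_(chunk)
            offset += n
        return output_buffer

    buf = torch.empty(6)
    write_chunks([torch.ones(2), torch.full((4,), 2.0)], buf)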
@@ -673,9 +666,19 @@ class Decoder(nn.Module):
         #we will design the protocol for signalling end to the caller later.
         raise NotImplementedError("Decoder.forward_resume is not implemented yet")
 
-    def forward(self, *args, **kwargs):
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Optional[torch.Tensor] = None,
+        output_buffer: Optional[torch.Tensor] = None,
+    ):
+        if output_buffer is None:
+            output_buffer = torch.empty(
+                self.decode_output_shape(sample.shape),
+                dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
+            )
         try:
-            return self.forward_orig(*args, **kwargs)
+            return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
         finally:
             for _, module in self.named_modules():
                 #ComfyUI doesn't thread this kind of stuff today, but just incase
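For callers the default path is unchanged, and an explicit buffer can still be reused across decodes. A hypothetical usage sketch, continuing the DecoderSketch above (the real forward also takes timestep and allocates on comfy.model_management.intermediate_device()):

    latent = torch.randn(1, 128, 8, 8)  # made-up latent shape
    decoder = DecoderSketch()

    # Default path: forward() allocates the output buffer itself.
    out = decoder.forward(latent)

    # Reuse path: a caller-owned buffer is passed straight through.
    buf = torch.empty(decoder.decode_output_shape(latent.shape), dtype=latent.dtype)
    out = decoder.forward(latent, output_buffer=buf)
    assert out is buf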