diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 539c0487c..e137ae28d 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -640,13 +640,6 @@ class Decoder(nn.Module):
             )
         timestep_shift_scale = ada_values.unbind(dim=1)
 
-        #The meaning of output_buffer == None changes.
-        #forward now needs to do this for the non start-resume case.
-        if output_buffer is None:
-            output_buffer = torch.empty(
-                self.decode_output_shape(sample.shape),
-                dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
-            )
         output_offset = [0]
 
         max_chunk_size = get_max_chunk_size(sample.device)
@@ -673,9 +666,19 @@ class Decoder(nn.Module):
         #we will design the protocol for signalling end to the caller later.
         raise NotImplementedError("Decoder.forward_resume is not implemented yet")
 
-    def forward(self, *args, **kwargs):
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Optional[torch.Tensor] = None,
+        output_buffer: Optional[torch.Tensor] = None,
+    ):
+        if output_buffer is None:
+            output_buffer = torch.empty(
+                self.decode_output_shape(sample.shape),
+                dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
+            )
         try:
-            return self.forward_orig(*args, **kwargs)
+            return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
         finally:
             for _, module in self.named_modules(): #ComfyUI doesn't thread this kind of stuff today, but just incase