From 1c2d37944c61ef8d421f8a782f3edb80c392dc9f Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 14 Apr 2026 11:15:47 +1000
Subject: [PATCH] nodes_image_stream: implement VAE decoder node

---
Review notes (text between the three-dash line and the diff is not part of
the commit message):

* Formatting: line structure and Python indentation were restored from a
  whitespace-collapsed copy of this patch. No added or context line was
  reworded. Hunk sizes match their headers (9 + 11 + 81 = 101 insertions).
  Indentation of the context lines was inferred from Python syntax; verify
  it against the target tree before applying.
* The From: header carries no e-mail address in this copy.
* VideoVAE.decode_start() clears the decoder's temporal cache state, applies
  the same timestep-conditioning noise mix that decode() uses (including its
  "#TODO: seed"), un-normalizes the latent and calls decoder.forward_start().
  VideoVAE.decode_chunk(output_t) forwards to decoder.forward_resume().
* VAE.decode_stream_start() loads the patcher onto the GPU and calls
  first_stage_model.decode_start(); it discards that call's return value.
  Unlike decode_output_shape() and decode_tiled(), it does not call
  throw_exception_if_invalid(); the caller added in this patch
  (VAEDecodedImageStream.__init__) performs that check itself.
* VAEDecodedImageStream rejects VAEs without comfy_has_chunked_io, latents
  that are not 5-D, and latents whose batch size is not 1. do_pull() returns
  an empty (0, height, width, channels) tensor when decode_chunk() returns
  None; otherwise it flattens the decoded chunk to frames and reports the
  running frame count as node progress.
* VAEDecodeStream tooltips mention LTX specifically, while the runtime gate
  is the comfy_has_chunked_io attribute -- NOTE(review): confirm whether the
  wording should stay LTX-specific.

 .../vae/causal_video_autoencoder.py           |  9 +++
 comfy/sd.py                                   | 11 +++
 comfy_extras/nodes_image_stream.py            | 81 +++++++++++++++++++
 3 files changed, 101 insertions(+)

diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index c0211addd..91606ffa6 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -1388,6 +1388,15 @@ class VideoVAE(nn.Module):
     def decode_output_shape(self, input_shape):
         return self.decoder.decode_output_shape(input_shape)
 
+    def decode_start(self, x):
+        clear_temporal_cache_state(self.decoder)
+        if self.timestep_conditioning: #TODO: seed
+            x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
+        return self.decoder.forward_start(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
+
+    def decode_chunk(self, output_t: int):
+        return self.decoder.forward_resume(output_t)
+
     def decode(self, x, output_buffer=None):
         if self.timestep_conditioning: #TODO: seed
             x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
diff --git a/comfy/sd.py b/comfy/sd.py
index e573804a5..9e442772e 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -995,6 +995,17 @@ class VAE:
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
 
+    def decode_output_shape(self, samples_shape):
+        self.throw_exception_if_invalid()
+        if hasattr(self.first_stage_model, "decode_output_shape"):
+            return self.first_stage_model.decode_output_shape(samples_shape)
+        raise RuntimeError("This VAE does not expose decode output shape information.")
+
+    def decode_stream_start(self, samples_in):
+        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+        self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))
+
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
diff --git a/comfy_extras/nodes_image_stream.py b/comfy_extras/nodes_image_stream.py
index fd2ac6915..f01075887 100644
--- a/comfy_extras/nodes_image_stream.py
+++ b/comfy_extras/nodes_image_stream.py
@@ -123,6 +123,60 @@ class PreviewingImageStream(Input.ImageStream):
         return chunk
 
 
+class VAEDecodedImageStream(Input.ImageStream):
+    def __init__(self, vae, latent: Input.Latent):
+        super().__init__()
+        self._vae = vae
+        self._latent = latent
+        vae.throw_exception_if_invalid()
+        if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
+            raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
+        if latent.ndim != 5:
+            raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
+        if latent.shape[0] != 1:
+            raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
+        output_shape = vae.decode_output_shape(latent.shape)
+        self._channels = int(output_shape[1])
+        self._width = int(output_shape[4])
+        self._height = int(output_shape[3])
+        self._total_frames = int(output_shape[0] * output_shape[2])
+        self._frames_emitted = 0
+
+    def _update_progress(self, value: int) -> None:
+        current = get_executing_context()
+        if current is None:
+            return
+        get_progress_state().update_progress(
+            current.node_id,
+            value=float(value),
+            max_value=float(max(self._total_frames, 1)),
+        )
+
+    def get_dimensions(self) -> tuple[int, int]:
+        return self._width, self._height
+
+    def do_reset(self) -> None:
+        self._frames_emitted = 0
+        self._update_progress(0)
+        self._vae.decode_stream_start(self._latent)
+
+    def do_pull(self, max_frames: int) -> Input.Image:
+        chunk = self._vae.first_stage_model.decode_chunk(max_frames)
+        if chunk is None:
+            return torch.empty(
+                (0, self._height, self._width, self._channels),
+                device=self._vae.output_device,
+                dtype=self._vae.vae_output_dtype(),
+            )
+
+        chunk = chunk.to(device=self._vae.output_device, dtype=self._vae.vae_output_dtype())
+        chunk = self._vae.process_output(chunk).movedim(1, -1)
+        chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))
+        self._frames_emitted += int(chunk.shape[0])
+        self._update_progress(self._frames_emitted)
+        return chunk
+
+
 class ImageBatchToStream(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -178,6 +232,32 @@ class ImageStreamToBatch(io.ComfyNode):
         return io.NodeOutput(torch.cat(chunks, dim=0))
 
 
+class VAEDecodeStream(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="VAEDecodeStream",
+            display_name="VAE Decode Stream",
+            category="image/stream",
+            search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
+            description="Decodes a latent into an IMAGE_STREAM.",
+            inputs=[
+                io.Latent.Input("samples", tooltip="The LTX latent to decode."),
+                io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
+            ],
+            outputs=[
+                io.ImageStream.Output(display_name="stream"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
+        latent = samples["samples"]
+        if latent.is_nested:
+            latent = latent.unbind()[0]
+        return io.NodeOutput(VAEDecodedImageStream(vae, latent))
+
+
 class PreviewImageStream(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -230,6 +310,7 @@ class ImageStreamExtension(ComfyExtension):
         return [
             ImageBatchToStream,
             ImageStreamToBatch,
+            VAEDecodeStream,
             PreviewImageStream,
             StreamSink,
         ]