From 1f0c02eb6c3ae0735006c53db2a258c8588e7a53 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 16:40:12 +1000 Subject: [PATCH 01/17] comfy_api: Add datatype for ImageStreams --- comfy/comfy_types/node_typing.py | 1 + comfy_api/input/__init__.py | 2 + comfy_api/input/image_stream_types.py | 6 +++ comfy_api/latest/__init__.py | 3 +- comfy_api/latest/_input/__init__.py | 2 + comfy_api/latest/_input/image_stream_types.py | 48 +++++++++++++++++++ comfy_api/latest/_io.py | 9 +++- 7 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 comfy_api/input/image_stream_types.py create mode 100644 comfy_api/latest/_input/image_stream_types.py diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 57126fa4a..18eba6dc5 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -51,6 +51,7 @@ class IO(StrEnum): BBOX = "BBOX" SEGS = "SEGS" VIDEO = "VIDEO" + IMAGE_STREAM = "IMAGE_STREAM" ANY = "*" """Always matches any type, but at a price. diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 16d4acfd1..8e2374aaf 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -2,6 +2,7 @@ from comfy_api.latest._input import ( ImageInput, AudioInput, + ImageStreamInput, MaskInput, LatentInput, VideoInput, @@ -14,6 +15,7 @@ from comfy_api.latest._input import ( __all__ = [ "ImageInput", "AudioInput", + "ImageStreamInput", "MaskInput", "LatentInput", "VideoInput", diff --git a/comfy_api/input/image_stream_types.py b/comfy_api/input/image_stream_types.py new file mode 100644 index 000000000..b52d0c76d --- /dev/null +++ b/comfy_api/input/image_stream_types.py @@ -0,0 +1,6 @@ +# This file only exists for backwards compatibility. +from comfy_api.latest._input.image_stream_types import ImageStreamInput + +__all__ = [ + "ImageStreamInput", +] diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 04973fea0..c0493b3ca 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class -from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput +from ._input import ImageInput, AudioInput, ImageStreamInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D from . 
import _io_public as io @@ -131,6 +131,7 @@ class ComfyExtension(ABC): class Input: Image = ImageInput Audio = AudioInput + ImageStream = ImageStreamInput Mask = MaskInput Latent = LatentInput Video = VideoInput diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 05cd3d40a..3ec611879 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,10 +1,12 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve +from .image_stream_types import ImageStreamInput from .video_types import VideoInput __all__ = [ "ImageInput", "AudioInput", + "ImageStreamInput", "VideoInput", "MaskInput", "LatentInput", diff --git a/comfy_api/latest/_input/image_stream_types.py b/comfy_api/latest/_input/image_stream_types.py new file mode 100644 index 000000000..04af3562c --- /dev/null +++ b/comfy_api/latest/_input/image_stream_types.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from .basic_types import ImageInput + + +class ImageStreamInput(ABC): + """Abstract base class for pull-based image stream inputs. + + Consumers request up to ``max_frames`` frames at a time. Producers must not + over-return; a batch with fewer than ``max_frames`` frames signals EOF. + """ + + def __init__(self): + #Subclasses must call this init for future core ComfyUI change compatibility + pass + + def reset(self) -> None: + #This API is final. Subclasses must NOT override this for future core ComfyUI + #change compatibility. Override do_reset instead. + return self.do_reset() + + def pull(self, max_frames: int) -> ImageInput: + #This API is final. Subclasses must NOT override this for future core ComfyUI + #change compatibility. Override do_pull instead. + return self.do_pull(max_frames) + + @abstractmethod + def get_dimensions(self) -> tuple[int, int]: + """Return the stream frame dimensions as ``(width, height)``.""" + pass + + @abstractmethod + def do_reset(self) -> None: + """Reset the stream so the next pull starts from frame 0.""" + pass + + @abstractmethod + def do_pull(self, max_frames: int) -> ImageInput: + """Return up to ``max_frames`` images. + + The returned tensor uses the normal ``IMAGE`` batch shape. A short + return, where the batch dimension is less than ``max_frames``, is the + EOF signal. Sources are expected to short-return at least once to + signal exhaustion, even if that means returning an empty batch. 
+ """ + pass diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index fdeffea2d..3fbd9e40e 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from comfy.samplers import CFGGuider, Sampler from comfy.sd import CLIP, VAE from comfy.sd import StyleModel as StyleModel_ - from comfy_api.input import VideoInput, CurveInput as CurveInput_ + from comfy_api.input import ImageStreamInput, VideoInput, CurveInput as CurveInput_ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker @@ -420,6 +420,12 @@ class Image(ComfyTypeIO): Type = torch.Tensor +@comfytype(io_type="IMAGE_STREAM") +class ImageStream(ComfyTypeIO): + if TYPE_CHECKING: + Type = ImageStreamInput + + @comfytype(io_type="WAN_CAMERA_EMBEDDING") class WanCameraEmbedding(ComfyTypeIO): Type = torch.Tensor @@ -2203,6 +2209,7 @@ __all__ = [ "Combo", "MultiCombo", "Image", + "ImageStream", "WanCameraEmbedding", "Webcam", "Mask", From 7dc366adc794ba539bc0ebd534fcff0558138e4b Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 9 Apr 2026 20:32:28 +1000 Subject: [PATCH 02/17] progress --- comfy_api/latest/_input/image_stream_types.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/comfy_api/latest/_input/image_stream_types.py b/comfy_api/latest/_input/image_stream_types.py index 04af3562c..a582859ad 100644 --- a/comfy_api/latest/_input/image_stream_types.py +++ b/comfy_api/latest/_input/image_stream_types.py @@ -1,7 +1,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +from contextlib import nullcontext +from comfy_execution.utils import CurrentNodeContext, get_executing_context +from comfy_execution.progress import get_progress_state from .basic_types import ImageInput @@ -14,17 +17,33 @@ class ImageStreamInput(ABC): def __init__(self): #Subclasses must call this init for future core ComfyUI change compatibility - pass + self._ctx = get_executing_context() def reset(self) -> None: #This API is final. Subclasses must NOT override this for future core ComfyUI #change compatibility. Override do_reset instead. - return self.do_reset() + with (nullcontext() if self._ctx is None else + CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)): + self.do_reset() + + if self._ctx is not None: + get_progress_state().finish_progress(self._ctx.node_id) def pull(self, max_frames: int) -> ImageInput: #This API is final. Subclasses must NOT override this for future core ComfyUI #change compatibility. Override do_pull instead. - return self.do_pull(max_frames) + with (nullcontext() if self._ctx is None else + CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)): + result = self.do_pull(max_frames) + + if self._ctx is not None: + registry = get_progress_state() + entry = registry.nodes.get(self._ctx.node_id) + if (int(result.shape[0]) < max_frames or + (entry is not None and entry["max"] > 0 and entry["value"] >= entry["max"])): + registry.finish_progress(self._ctx.node_id) + + return result @abstractmethod def get_dimensions(self) -> tuple[int, int]: From 7b1d4bcdf6c7ef5da65a7199a28937be61612e94 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 18:25:27 +1000 Subject: [PATCH 03/17] node_image_stream: add Add some nodes for getting into and out of stream mode. 
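For reference, consumption is a reset-then-pull loop against the ImageStreamInput API from patch 01. A minimal sketch (illustration only, not part of the diff), using the TensorImageStream wrapper this patch adds, with an arbitrary frame count and chunk size:

    import torch
    # Wrap 10 random 64x64 RGB frames as a pull-based stream.
    stream = TensorImageStream(torch.rand(10, 64, 64, 3))
    stream.reset()
    frames = 0
    while True:
        chunk = stream.pull(4)         # BHWC batch with at most 4 frames
        frames += int(chunk.shape[0])
        if chunk.shape[0] < 4:         # a short (possibly empty) return is EOF
            break
    assert frames == 10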
--- comfy_extras/nodes_image_stream.py | 174 +++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 175 insertions(+) create mode 100644 comfy_extras/nodes_image_stream.py diff --git a/comfy_extras/nodes_image_stream.py b/comfy_extras/nodes_image_stream.py new file mode 100644 index 000000000..12ba1d908 --- /dev/null +++ b/comfy_extras/nodes_image_stream.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, Input, io, ui +from comfy_execution.progress import get_progress_state +from comfy_execution.utils import get_executing_context +from server import PromptServer + + +class FrameProgressTracker: + def __init__(self): + self._last_reported = None + + def emit(self, frames_processed): + if frames_processed == self._last_reported: + return + + current = get_executing_context() + server = getattr(PromptServer, "instance", None) + if current is None or server is None or server.client_id is None: + return + + server.send_progress_text( + f"processed {frames_processed} frames", + current.node_id, + server.client_id, + ) + self._last_reported = frames_processed + + +def drain_image_stream(stream, chunk_size, progress): + stream.reset() + frames_processed = 0 + + while True: + progress.emit(frames_processed) + chunk = stream.pull(chunk_size) + frames_processed += int(chunk.shape[0]) + if chunk.shape[0] < chunk_size: + progress.emit(frames_processed) + return frames_processed + +class TensorImageStream(Input.ImageStream): + """Simple IMAGE_STREAM backed by a materialized IMAGE batch tensor.""" + + def __init__(self, images: Input.Image): + super().__init__() + self._images = images + self._index = 0 + self._total_frames = int(images.shape[0]) + self._progress_started = False + + def _update_progress(self, value: int) -> None: + current = get_executing_context() + if current is None: + return + get_progress_state().update_progress( + current.node_id, + value=float(value), + max_value=float(max(self._total_frames, 1)), + ) + + def get_dimensions(self) -> tuple[int, int]: + return self._images.shape[2], self._images.shape[1] + + def do_reset(self) -> None: + self._index = 0 + self._progress_started = False + + def do_pull(self, max_frames: int) -> Input.Image: + if not self._progress_started: + self._update_progress(0) + self._progress_started = True + + start = self._index + end = min(start + max_frames, self._images.shape[0]) + self._index = end + chunk = self._images[start:end].clone() + self._update_progress(end) + return chunk + + +class ImageBatchToStream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageBatchToStream", + display_name="Image Batch To Stream", + category="image/stream", + search_aliases=["image to stream", "batch to stream", "frames to stream"], + description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.", + inputs=[ + io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format."), + ], + outputs=[ + io.ImageStream.Output(display_name="stream"), + ], + ) + + @classmethod + def execute(cls, image: Input.Image) -> io.NodeOutput: + return io.NodeOutput(TensorImageStream(image)) + + +class ImageStreamToBatch(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageStreamToBatch", + display_name="Image Stream To Batch", + category="image/stream", + search_aliases=["stream to image", "stream to batch", "collect stream"], + description="Materializes an IMAGE_STREAM back into 
a batched IMAGE tensor.", + inputs=[ + io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM."), + io.Int.Input("batch_size", default=4096, min=1, max=4096), + ], + outputs=[ + io.Image.Output(display_name="image"), + ], + ) + + @classmethod + def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput: + chunks: list[Input.Image] = [] + stream.reset() + + while True: + chunk = stream.pull(batch_size) + + chunks.append(chunk) + if chunk.shape[0] < batch_size: + break + + return io.NodeOutput(torch.cat(chunks, dim=0)) + + +class StreamSink(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StreamSink", + search_aliases=["consume stream", "drain stream", "image stream sink"], + display_name="Stream Sink", + category="image/stream", + description="Consumes an IMAGE_STREAM by pulling it to EOF.", + inputs=[ + io.ImageStream.Input("stream", tooltip="The image stream to consume."), + io.Int.Input("chunk_size", default=8, min=1, max=4096), + ], + outputs=[], + is_output_node=True, + ) + + @classmethod + def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput: + drain_image_stream(stream, chunk_size, progress=FrameProgressTracker()) + return io.NodeOutput() + + +class ImageStreamExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ImageBatchToStream, + ImageStreamToBatch, + StreamSink, + ] + + +async def comfy_entrypoint() -> ImageStreamExtension: + return ImageStreamExtension() diff --git a/nodes.py b/nodes.py index 299b3d758..66ab31ad3 100644 --- a/nodes.py +++ b/nodes.py @@ -2414,6 +2414,7 @@ async def init_builtin_extra_nodes(): "nodes_hooks.py", "nodes_load_3d.py", "nodes_cosmos.py", + "nodes_image_stream.py", "nodes_video.py", "nodes_lumina2.py", "nodes_wan.py", From 37c578f2dd85dea902d475997f50adf81093c478 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 18:40:06 +1000 Subject: [PATCH 04/17] nodes_upscale_model: add stream support --- comfy_extras/nodes_upscale_model.py | 44 +++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d3ee3f1c1..e2ae24422 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -5,8 +5,7 @@ import torch import comfy.utils import folder_paths from typing_extensions import override -from comfy_api.latest import ComfyExtension, io -import comfy.model_management +from comfy_api.latest import ComfyExtension, Input, io try: from spandrel_extra_arches import EXTRA_REGISTRY @@ -47,9 +46,29 @@ class UpscaleModelLoader(io.ComfyNode): load_model = execute # TODO: remove +class UpscaledImageStream(Input.ImageStream): + def __init__(self, upscale_model, stream: Input.ImageStream): + super().__init__() + self._upscale_model = upscale_model + self._stream = stream + + def get_dimensions(self) -> tuple[int, int]: + width, height = self._stream.get_dimensions() + scale = self._upscale_model.scale + return int(width * scale), int(height * scale) + + def do_reset(self) -> None: + self._stream.reset() + + def do_pull(self, max_frames: int) -> Input.Image: + chunk = self._stream.pull(max_frames) + return ImageUpscaleWithModel.upscale_batch(self._upscale_model, chunk) + + class ImageUpscaleWithModel(io.ComfyNode): @classmethod def define_schema(cls): + image_template = io.MatchType.Template("image_type", allowed_types=[io.Image, io.ImageStream]) return io.Schema( 
node_id="ImageUpscaleWithModel", display_name="Upscale Image (using Model)", @@ -57,15 +76,18 @@ class ImageUpscaleWithModel(io.ComfyNode): search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"], inputs=[ io.UpscaleModel.Input("upscale_model"), - io.Image.Input("image"), + io.MatchType.Input("image", template=image_template), ], outputs=[ - io.Image.Output(), + io.MatchType.Output(template=image_template, display_name="image"), ], ) @classmethod - def execute(cls, upscale_model, image) -> io.NodeOutput: + def upscale_batch(cls, upscale_model, image: torch.Tensor) -> torch.Tensor: + if image.shape[0] == 0: + return image.clone() + device = model_management.get_torch_device() memory_required = model_management.module_size(upscale_model.model) @@ -79,7 +101,7 @@ class ImageUpscaleWithModel(io.ComfyNode): tile = 512 overlap = 32 - output_device = comfy.model_management.intermediate_device() + output_device = model_management.intermediate_device() oom = True try: @@ -97,8 +119,14 @@ class ImageUpscaleWithModel(io.ComfyNode): finally: upscale_model.to("cpu") - s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) - return io.NodeOutput(s) + return torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(model_management.intermediate_dtype()) + + @classmethod + def execute(cls, upscale_model, image) -> io.NodeOutput: + if isinstance(image, torch.Tensor): + return io.NodeOutput(cls.upscale_batch(upscale_model, image)) + + return io.NodeOutput(UpscaledImageStream(upscale_model, image)) upscale = execute # TODO: remove From 7719eb68774a1c95250c179c7bc638f409b9c562 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 13 Apr 2026 22:28:01 +1000 Subject: [PATCH 05/17] video_types: split up saver to streamable state Split this up into start -> chunk -> finish so it can be saved piece by piece. --- comfy_api/latest/_input_impl/video_types.py | 52 ++++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 1b4993aa7..4cde7763a 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -386,6 +386,7 @@ class VideoFromComponents(VideoInput): def __init__(self, components: VideoComponents): self.__components = components + self._frame_counter = 0 def get_components(self) -> VideoComponents: return VideoComponents( @@ -394,14 +395,13 @@ class VideoFromComponents(VideoInput): frame_rate=self.__components.frame_rate, ) - def save_to( + def save_start( self, path: str, format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, ): - """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: @@ -413,7 +413,12 @@ class VideoFromComponents(VideoInput): # BytesIO has no file extension, so av.open can't infer the format. # Default to mp4 since that's the only supported format anyway. 
extra_kwargs["format"] = "mp4" - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: + + width, height = self.get_dimensions() + + output = av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) + if True: + # Add metadata before writing any streams if metadata is not None: for key, value in metadata.items(): @@ -422,8 +427,8 @@ class VideoFromComponents(VideoInput): frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) # Create a video stream video_stream = output.add_stream('h264', rate=frame_rate) - video_stream.width = self.__components.images.shape[2] - video_stream.height = self.__components.images.shape[1] + video_stream.width = width + video_stream.height = height video_stream.pix_fmt = 'yuv420p' # Create an audio stream @@ -432,23 +437,33 @@ class VideoFromComponents(VideoInput): if self.__components.audio: audio_sample_rate = int(self.__components.audio['sample_rate']) waveform = self.__components.audio['waveform'] - waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] + waveform = waveform[0] layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo') audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout) + self._frame_counter = 0 + return output, video_stream, audio_stream, audio_sample_rate, frame_rate + + def save_add(self, output, video_stream, images) -> None: # Encode video - for i, frame in enumerate(self.__components.images): + for frame in images: img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) frame = av.VideoFrame.from_ndarray(img, format='rgb24') frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 packet = video_stream.encode(frame) output.mux(packet) + self._frame_counter += 1 + def save_finalize(self, output, video_stream, audio_stream, audio_sample_rate, frame_rate) -> None: # Flush video packet = video_stream.encode(None) output.mux(packet) if audio_stream and self.__components.audio: + waveform = self.__components.audio['waveform'] + waveform = waveform[0] + layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo') + waveform = waveform[:, :math.ceil((audio_sample_rate / frame_rate) * self._frame_counter)] frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout) frame.sample_rate = audio_sample_rate frame.pts = 0 @@ -457,6 +472,29 @@ class VideoFromComponents(VideoInput): # Flush encoder output.mux(audio_stream.encode(None)) + output.close() + + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None, + ): + """Save the video to a file path or BytesIO buffer.""" + output, video_stream, audio_stream, audio_sample_rate, frame_rate = self.save_start( + path, + format=format, + codec=codec, + metadata=metadata, + ) + try: + self.save_add(output, video_stream, self.__components.images) + self.save_finalize(output, video_stream, audio_stream, audio_sample_rate, frame_rate) + except Exception: + output.close() + raise + def as_trimmed( self, start_time: float | None = None, From 05b5240d7091d311016e19b54eea9a2902bea1d7 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 13 Apr 2026 22:30:57 +1000 Subject: [PATCH 06/17] nodes_image_stream: implement stream live preview node --- comfy_extras/nodes_image_stream.py | 65 ++++++++++++++++++++++++++++++ 1 file changed, 
65 insertions(+) diff --git a/comfy_extras/nodes_image_stream.py b/comfy_extras/nodes_image_stream.py index 12ba1d908..fd2ac6915 100644 --- a/comfy_extras/nodes_image_stream.py +++ b/comfy_extras/nodes_image_stream.py @@ -82,6 +82,47 @@ class TensorImageStream(Input.ImageStream): return chunk +class PreviewingImageStream(Input.ImageStream): + def __init__(self, stream: Input.ImageStream): + super().__init__() + self._stream = stream + + def _emit_preview(self, chunk: Input.Image) -> None: + if int(chunk.shape[0]) == 0: + return + + current = get_executing_context() + if current is None: + return + + server = getattr(PromptServer, "instance", None) + if server is None or server.client_id is None: + return + + preview_output = ui.PreviewImage(chunk[-1:]).as_dict() + server.send_sync( + "executed", + { + "node": current.node_id, + "display_node": current.node_id, + "output": preview_output, + "prompt_id": current.prompt_id, + }, + server.client_id, + ) + + def get_dimensions(self) -> tuple[int, int]: + return self._stream.get_dimensions() + + def do_reset(self) -> None: + self._stream.reset() + + def do_pull(self, max_frames: int) -> Input.Image: + chunk = self._stream.pull(max_frames) + self._emit_preview(chunk) + return chunk + + class ImageBatchToStream(io.ComfyNode): @classmethod def define_schema(cls): @@ -137,6 +178,29 @@ class ImageStreamToBatch(io.ComfyNode): return io.NodeOutput(torch.cat(chunks, dim=0)) +class PreviewImageStream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PreviewImageStream", + display_name="Preview Image Stream", + category="image/stream", + search_aliases=["stream preview", "preview frames", "preview image stream"], + description="Passes an IMAGE_STREAM through while previewing the last frame from each pulled chunk.", + has_intermediate_output=True, + inputs=[ + io.ImageStream.Input("stream", tooltip="The image stream to preview inline."), + ], + outputs=[ + io.ImageStream.Output(display_name="passthrough"), + ], + ) + + @classmethod + def execute(cls, stream: Input.ImageStream) -> io.NodeOutput: + return io.NodeOutput(PreviewingImageStream(stream)) + + class StreamSink(io.ComfyNode): @classmethod def define_schema(cls): @@ -166,6 +230,7 @@ class ImageStreamExtension(ComfyExtension): return [ ImageBatchToStream, ImageStreamToBatch, + PreviewImageStream, StreamSink, ] From 3d21d2afb6eaa04aa7beae3648793329ad5ddfdb Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 13 Apr 2026 22:34:51 +1000 Subject: [PATCH 07/17] nodes_video: implement stream saver nodes --- comfy_extras/nodes_video.py | 242 +++++++++++++++++++++++++++++++++++- 1 file changed, 241 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 5c096c232..f04653d28 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -5,11 +5,162 @@ import av import torch import folder_paths import json -from typing import Optional +from typing import Callable, Optional from typing_extensions import override from fractions import Fraction from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types from comfy.cli_args import args +from comfy_execution.utils import get_executing_context +from comfy_extras.nodes_image_stream import FrameProgressTracker, drain_image_stream +from server import PromptServer + + +class SavedVideoStream(Input.ImageStream): + def __init__( + self, + stream: Input.ImageStream, + saver, + output_factory: Callable[[], tuple[str, ui.PreviewVideo | None]], + format: str, + 
codec, + metadata: Optional[dict], + emit_preview_on_finalize: bool = True, + preview_node_id: Optional[str] = None, + preview_display_node_id: Optional[str] = None, + ): + super().__init__() + self._stream = stream + self._saver = saver + self._output_factory = output_factory + self._path: str | None = None + self._format = Types.VideoContainer(format) + self._codec = Types.VideoCodec(codec) + self._metadata = metadata + self._preview_ui: ui.PreviewVideo | None = None + self._emit_preview_on_finalize = emit_preview_on_finalize + self._preview_node_id = preview_node_id + self._preview_display_node_id = preview_display_node_id + self._save_state = None + + def _emit_preview(self) -> None: + if not self._emit_preview_on_finalize or self._preview_ui is None: + return + + current = get_executing_context() + if current is None: + return + + server = getattr(PromptServer, "instance", None) + if server is None or server.client_id is None: + return + + server.send_sync( + "executed", + { + "node": self._preview_node_id or current.node_id, + "display_node": self._preview_display_node_id or self._preview_node_id or current.node_id, + "output": self._preview_ui.as_dict(), + "prompt_id": current.prompt_id, + }, + server.client_id, + ) + + def _discard_partial_output(self) -> None: + if self._save_state is not None: + self._save_state[0].close() + self._save_state = None + if self._path is not None and os.path.exists(self._path): + os.remove(self._path) + self._path = None + self._preview_ui = None + + def get_preview_ui(self) -> ui.PreviewVideo | None: + return self._preview_ui + + def get_dimensions(self) -> tuple[int, int]: + return self._stream.get_dimensions() + + def do_reset(self) -> None: + self._discard_partial_output() + self._stream.reset() + self._path, self._preview_ui = self._output_factory() + assert self._path is not None + open(self._path, "ab").close() + self._save_state = self._saver.save_start( + self._path, + format=self._format, + codec=self._codec, + metadata=self._metadata, + ) + + def do_pull(self, max_frames: int) -> Input.Image: + assert self._save_state is not None + chunk = self._stream.pull(max_frames) + self._saver.save_add(self._save_state[0], self._save_state[1], chunk) + if chunk.shape[0] < max_frames: + self._saver.save_finalize(*self._save_state) + self._save_state = None + self._emit_preview() + self._path = None + return chunk + + +def _build_saved_stream( + hidden, + stream: Input.ImageStream, + audio: Optional[Input.Audio], + fps: float, + filename_prefix, + format: str, + codec, + emit_preview: bool = True, +) -> SavedVideoStream: + width, height = stream.get_dimensions() + saved_metadata = None + if not args.disable_metadata: + metadata = {} + if hidden.extra_pnginfo is not None: + metadata.update(hidden.extra_pnginfo) + if hidden.prompt is not None: + metadata["prompt"] = hidden.prompt + if len(metadata) > 0: + saved_metadata = metadata + + preview_node_id = hidden.unique_id + preview_display_node_id = preview_node_id + if hidden.dynprompt is not None and preview_node_id is not None: + preview_display_node_id = hidden.dynprompt.get_display_node_id(preview_node_id) + + def output_factory() -> tuple[str, ui.PreviewVideo]: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, + folder_paths.get_output_directory(), + width, + height, + ) + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" + return ( + os.path.join(full_output_folder, file), + ui.PreviewVideo([ui.SavedResult(file, 
subfolder, io.FolderType.output)]), + ) + + return SavedVideoStream( + stream, + InputImpl.VideoFromComponents( + Types.VideoComponents( + images=torch.zeros((0, height, width, 3)), + audio=audio, + frame_rate=Fraction(fps), + ) + ), + output_factory, + format, + codec, + saved_metadata, + emit_preview, + preview_node_id, + preview_display_node_id, + ) class SaveWEBM(io.ComfyNode): @classmethod @@ -114,6 +265,93 @@ class SaveVideo(io.ComfyNode): return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) +class SavePassthroughVideoStream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SavePassthroughVideoStream", + search_aliases=["stream to video", "save image stream", "export video stream", "passthrough video stream"], + display_name="Save+Passthrough Video Stream", + category="image/video", + essentials_category="Basics", + description="Saves frames as they pass through the input image stream.", + has_intermediate_output=True, + inputs=[ + io.ImageStream.Input("stream", tooltip="The image stream to save."), + io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + outputs=[ + io.ImageStream.Output(display_name="passthrough"), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt], + ) + + @classmethod + def execute( + cls, + stream: Input.ImageStream, + audio: Optional[Input.Audio], + fps: float, + filename_prefix, + format: str, + codec, + ) -> io.NodeOutput: + return io.NodeOutput(_build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec)) + + +class SaveVideoStream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveVideoStream", + search_aliases=["save image stream", "export video stream", "stream to video"], + display_name="Save Video Stream", + category="image/video", + essentials_category="Basics", + description="Saves an image stream by draining it directly to EOF.", + is_output_node=True, + inputs=[ + io.ImageStream.Input("stream", tooltip="The image stream to save."), + io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), + io.Int.Input("chunk_size", default=8, min=1, max=4096), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt], + ) + + @classmethod + def execute( + cls, + stream: Input.ImageStream, + audio: Optional[Input.Audio], + fps: float, + chunk_size: int, + filename_prefix, + format: str, + codec, + ) -> io.NodeOutput: + saved_stream = _build_saved_stream( + cls.hidden, + stream, + audio, + fps, + filename_prefix, + format, + codec, + emit_preview=False, + ) + drain_image_stream(saved_stream, chunk_size, progress=FrameProgressTracker()) + return io.NodeOutput(ui=saved_stream.get_preview_ui()) + + class CreateVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -262,6 +500,8 @@ class VideoExtension(ComfyExtension): return [ SaveWEBM, SaveVideo, + SavePassthroughVideoStream, + SaveVideoStream, CreateVideo, GetVideoComponents, LoadVideo, From 9dede56facdbdf5698d0e81c8024a00b01070103 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 08:45:18 +1000 Subject: [PATCH 08/17] implementation plan --- .../vae/causal_video_autoencoder.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 998122c85..539c0487c 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -553,6 +553,10 @@ class Decoder(nn.Module): t = sample.shape[2] output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) output_offset[0] += t + #if there isn't space in the output buffer, you need to stash unconsumed + #frames in the new state as its own thing. Forward resume then just immediately + #copies those into the next slice of output. Do not clone. Just slice. + #The VRAM is not a big deal. return up_block = self.up_blocks[idx] @@ -575,11 +579,17 @@ class Decoder(nn.Module): # when we are not chunking, detach our x so the callee can free it as soon as they are done next_sample_ref = [sample] del sample + #Just let this run_up unconditionally regardless, it's ok because either a lower layer + #chunker or output frame stash will do the work anyway, so unchanged. self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) return else: samples = torch.chunk(sample, chunks=num_chunks, dim=2) + #This loop might need to become a while. + #If the output buffer is exhausted (or none), it needs to stash whatever is left of the samples + #list to new state. + #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T. for chunk_idx, sample1 in enumerate(samples): self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) @@ -630,6 +640,8 @@ class Decoder(nn.Module): ) timestep_shift_scale = ada_values.unbind(dim=1) + #The meaning of output_buffer == None changes. + #forward now needs to do this for the non start-resume case. 
if output_buffer is None: output_buffer = torch.empty( self.decode_output_shape(sample.shape), dtype=sample.dtype, device=comfy.model_management.intermediate_device(), ) output_offset = [0] max_chunk_size = get_max_chunk_size(sample.device) self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) return output_buffer + def forward_start(self, *args, **kwargs): + #output_buffer == None implies initial exhaustion, so this should park it in a + #resumable state on the bottom of the run_up stack. + raise NotImplementedError("Decoder.forward_start is not implemented yet") + + #completely new function (maybe) + def forward_resume(self, *args, **kwargs): + #your code here + #inspect the new state. Pop any complete frames first. + #Then execute run_up on the highest index frame. You will need to be the list iterator + #for chunked run_up stashes and potentially restash incomplete lists. + #come down the indices as you stash highest to lowest and just like run_up you need to + #return if output_buffer is exhausted with the extra run_up calls potentially leaving + #behind new state for the next time this is called. + #if we have truly finished, do the same logic as the finally below to clean up. + #we will design the protocol for signalling end to the caller later. + raise NotImplementedError("Decoder.forward_resume is not implemented yet") + def forward(self, *args, **kwargs): try: return self.forward_orig(*args, **kwargs) finally: for _, module in self.named_modules(): #ComfyUI doesn't thread this kind of stuff today, but just in case #we key on the thread to make it thread safe. tid = threading.get_ident() if hasattr(module, "temporal_cache_state"): module.temporal_cache_state.pop(tid, None) From ad91467ef6d5b2aecb5c0a2946276b2f63edac81 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 09:00:50 +1000 Subject: [PATCH 09/17] ltx: vae: Move decoder output buffer allocation back to forward. So none can mean none. --- .../vae/causal_video_autoencoder.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 539c0487c..e137ae28d 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -640,13 +640,6 @@ class Decoder(nn.Module): ) timestep_shift_scale = ada_values.unbind(dim=1) - #The meaning of output_buffer == None changes. - #forward now needs to do this for the non start-resume case. - if output_buffer is None: - output_buffer = torch.empty( - self.decode_output_shape(sample.shape), - dtype=sample.dtype, device=comfy.model_management.intermediate_device(), - ) output_offset = [0] max_chunk_size = get_max_chunk_size(sample.device) @@ -673,9 +666,19 @@ return output_buffer - def forward(self, *args, **kwargs): + def forward( + self, + sample: torch.FloatTensor, + timestep: Optional[torch.Tensor] = None, + output_buffer: Optional[torch.Tensor] = None, + ): + if output_buffer is None: + output_buffer = torch.empty( + self.decode_output_shape(sample.shape), + dtype=sample.dtype, device=comfy.model_management.intermediate_device(), + ) try: - return self.forward_orig(*args, **kwargs) + return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer) finally: for _, module in self.named_modules(): #ComfyUI doesn't thread this kind of stuff today, but just in case #we key on the thread to make it thread safe. tid = threading.get_ident() From ab9e006873adea59131dac58ee0aef3ba53a5a03 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 09:08:28 +1000 Subject: [PATCH 10/17] ltx: vae: Move constants to a state class. Consolidate these into a state class. This will expand with more content. Save it to the Decoder module itself for reusability. 
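The per-thread keying is what makes the stash reusable: a later call on the same thread can find the state again without it being passed through every signature. Illustration only (not part of the diff; a later patch adds the real resume path on top of this):

    import threading

    def get_run_up_state(decoder):
        # Fetch the per-call state that forward_orig stashes on the
        # Decoder; None when no decode is in flight on this thread.
        return decoder.temporal_cache_state.get(threading.get_ident(), None)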
--- .../vae/causal_video_autoencoder.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index e137ae28d..ab1990898 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -16,6 +16,12 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +class RunUpState: + def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn): + self.timestep_shift_scale = timestep_shift_scale + self.scaled_timestep = scaled_timestep + self.checkpoint_fn = checkpoint_fn + def in_meta_context(): return torch.device("meta") == torch.empty(0).device @@ -530,19 +536,20 @@ class Decoder(nn.Module): ).unsqueeze(1).expand(2, output_channel), persistent=False, ) + self.temporal_cache_state = {} def decode_output_shape(self, input_shape): c, (ts, hs, ws), to = self._output_scale return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) - def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size): + def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size): sample = sample_ref[0] sample_ref[0] = None if idx >= len(self.up_blocks): sample = self.conv_norm_out(sample) - if timestep_shift_scale is not None: - shift, scale = timestep_shift_scale + if run_up_state.timestep_shift_scale is not None: + shift, scale = run_up_state.timestep_shift_scale sample = sample * (1 + scale) + shift sample = self.conv_act(sample) if ended: @@ -563,11 +570,11 @@ class Decoder(nn.Module): if ended: mark_conv3d_ended(up_block) if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep + sample = run_up_state.checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=run_up_state.scaled_timestep ) else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) + sample = run_up_state.checkpoint_fn(up_block)(sample, causal=self.causal) if sample is None or sample.shape[2] == 0: return @@ -581,7 +588,7 @@ class Decoder(nn.Module): del sample #Just let this run_up unconditionally regardless, it's ok because either a lower layer #chunker or output frame stash will do the work anyway, so unchanged. - self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size) return else: samples = torch.chunk(sample, chunks=num_chunks, dim=2) #This loop might need to become a while. #If the output buffer is exhausted (or none), it needs to stash whatever is left of the samples #list to new state. #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T. 
for chunk_idx, sample1 in enumerate(samples): - self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset, max_chunk_size) def forward_orig( self, @@ -643,8 +650,14 @@ ... output_offset = [0] max_chunk_size = get_max_chunk_size(sample.device) + run_up_state = RunUpState( + timestep_shift_scale=timestep_shift_scale, + scaled_timestep=scaled_timestep, + checkpoint_fn=checkpoint_fn, + ) + self.temporal_cache_state[threading.get_ident()] = run_up_state - self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset, max_chunk_size) return output_buffer From 930df2d70bcbf554e4e04f0fe652731d11d0e1eb Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 09:36:57 +1000 Subject: [PATCH 11/17] ltx: vae: save left over tail frames to state If it doesn't fit, stash it. --- .../lightricks/vae/causal_video_autoencoder.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index ab1990898..209f35758 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -17,10 +17,11 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init class RunUpState: - def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn): + def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_frames=None): self.timestep_shift_scale = timestep_shift_scale self.scaled_timestep = scaled_timestep self.checkpoint_fn = checkpoint_fn + self.output_frames = output_frames def in_meta_context(): return torch.device("meta") == torch.empty(0).device @@ -557,13 +558,15 @@ class Decoder(nn.Module): sample = self.conv_out(sample, causal=self.causal) if sample is not None and sample.shape[2] > 0: sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - t = sample.shape[2] - output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + if output_buffer is None: + run_up_state.output_frames = sample + return + output_slice = output_buffer[:, :, output_offset[0]:output_offset[0] + sample.shape[2]] + t = output_slice.shape[2] + output_slice.copy_(sample[:, :, :t]) output_offset[0] += t - #if there isn't space in the output buffer, you need to stash unconsumed - #frames in the new state as its own thing. Forward resume then just immediately - #copies those into the next slice of output. Do not clone. Just slice. - #The VRAM is not a big deal. 
+ if t < sample.shape[2]: + run_up_state.output_frames = sample[:, :, t:] return up_block = self.up_blocks[idx] From ce054bbf2d30fa7f462b8904c39e3389fbaeee32 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 09:48:51 +1000 Subject: [PATCH 12/17] ltx: vae: move max_chunk_size to the RunUpState --- .../lightricks/vae/causal_video_autoencoder.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 209f35758..0b2d0edcf 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -17,10 +17,11 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init class RunUpState: - def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_frames=None): + def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_frames=None): self.timestep_shift_scale = timestep_shift_scale self.scaled_timestep = scaled_timestep self.checkpoint_fn = checkpoint_fn + self.max_chunk_size = max_chunk_size self.output_frames = output_frames def in_meta_context(): return torch.device("meta") == torch.empty(0).device @@ -544,7 +545,7 @@ class Decoder(nn.Module): c, (ts, hs, ws), to = self._output_scale return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) - def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size): + def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset): sample = sample_ref[0] sample_ref[0] = None if idx >= len(self.up_blocks): @@ -583,7 +584,7 @@ class Decoder(nn.Module): return total_bytes = sample.numel() * sample.element_size() - num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size + num_chunks = (total_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size if num_chunks == 1: # when we are not chunking, detach our x so the callee can free it as soon as they are done @@ -591,7 +592,7 @@ del sample #Just let this run_up unconditionally regardless, it's ok because either a lower layer #chunker or output frame stash will do the work anyway, so unchanged. - self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size) + self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset) return else: samples = torch.chunk(sample, chunks=num_chunks, dim=2) #This loop might need to become a while. #If the output buffer is exhausted (or none), it needs to stash whatever is left of the samples #list to new state. #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T. 
for chunk_idx, sample1 in enumerate(samples): - self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset, max_chunk_size) + self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset) def forward_orig( self, @@ -652,15 +653,15 @@ ... output_offset = [0] - max_chunk_size = get_max_chunk_size(sample.device) run_up_state = RunUpState( timestep_shift_scale=timestep_shift_scale, scaled_timestep=scaled_timestep, checkpoint_fn=checkpoint_fn, + max_chunk_size=get_max_chunk_size(sample.device), ) self.temporal_cache_state[threading.get_ident()] = run_up_state - self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset, max_chunk_size) + self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset) return output_buffer From 4ba0b9125d8da56b006475d9c39d296acbd733ec Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 10:09:43 +1000 Subject: [PATCH 13/17] ltx: vae: save un-actionable chunks to RunUpState. --- .../ldm/lightricks/vae/causal_video_autoencoder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 0b2d0edcf..55be72e6f 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -23,6 +23,7 @@ class RunUpState: self.checkpoint_fn = checkpoint_fn self.max_chunk_size = max_chunk_size self.output_frames = output_frames + self.pending_samples = [] def in_meta_context(): return torch.device("meta") == torch.empty(0).device @@ -595,14 +596,13 @@ class Decoder(nn.Module): self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset) return else: - samples = torch.chunk(sample, chunks=num_chunks, dim=2) + samples = list(torch.chunk(sample, chunks=num_chunks, dim=2)) - #This loop might need to become a while. - #If the output buffer is exhausted (or none), it needs to stash whatever is left of the samples - #list to new state. - #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T. 
- for chunk_idx, sample1 in enumerate(samples): - self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset) + while len(samples): + if output_buffer is None or output_offset[0] == output_buffer.shape[2]: + run_up_state.pending_samples.append((idx + 1, samples, ended)) + return + self.run_up(idx + 1, [samples.pop(0)], ended and len(samples) == 1, run_up_state, output_buffer, output_offset) def forward_orig( self, From 06381d5d1880e23ebed4d7e29ee2a65ef0198f51 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 10:16:17 +1000 Subject: [PATCH 14/17] ltx: vae: consolidate cache clearer function --- .../vae/causal_video_autoencoder.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 55be72e6f..13c1c6a1b 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -35,6 +35,14 @@ def mark_conv3d_ended(module): current = m.temporal_cache_state.get(tid, (None, False)) m.temporal_cache_state[tid] = (current[0], True) +def clear_temporal_cache_state(module): + # ComfyUI doesn't thread this kind of stuff today, but just in case + # we key on the thread to make it thread safe. + tid = threading.get_ident() + for _, m in module.named_modules(): + if hasattr(m, "temporal_cache_state"): + m.temporal_cache_state.pop(tid, None) + def split2(tensor, split_point, dim=2): return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim) @@ -324,13 +332,7 @@ class Encoder(nn.Module): try: return self.forward_orig(*args, **kwargs) finally: - tid = threading.get_ident() - for _, module in self.named_modules(): - # ComfyUI doesn't thread this kind of stuff today, but just in case - # we key on the thread to make it thread safe. - tid = threading.get_ident() - if hasattr(module, "temporal_cache_state"): - module.temporal_cache_state.pop(tid, None) + clear_temporal_cache_state(self) MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3 @@ -697,12 +699,7 @@ ... try: return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer) finally: - for _, module in self.named_modules(): - #ComfyUI doesn't thread this kind of stuff today, but just in case - #we key on the thread to make it thread safe. 
- tid = threading.get_ident() - if hasattr(module, "temporal_cache_state"): - module.temporal_cache_state.pop(tid, None) + clear_temporal_cache_state(self) From b23f1f456fbaa2ccddb3812b3d7a59f36ae54d7a Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 16:14:52 +1000 Subject: [PATCH 15/17] ltx: vae: implement start and resume protocol --- .../vae/causal_video_autoencoder.py | 73 ++++++++++++++----- 1 file changed, 56 insertions(+), 17 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 13c1c6a1b..c0211addd 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -17,11 +17,13 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init class RunUpState: - def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_frames=None): + def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None): self.timestep_shift_scale = timestep_shift_scale self.scaled_timestep = scaled_timestep self.checkpoint_fn = checkpoint_fn self.max_chunk_size = max_chunk_size + self.output_shape = output_shape + self.output_dtype = output_dtype self.output_frames = output_frames self.pending_samples = [] @@ -614,6 +616,7 @@ class Decoder(nn.Module): ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" batch_size = sample.shape[0] + output_shape = self.decode_output_shape(sample.shape) mark_conv3d_ended(self.conv_in) sample = self.conv_in(sample, causal=self.causal) @@ -660,6 +663,8 @@ scaled_timestep=scaled_timestep, checkpoint_fn=checkpoint_fn, max_chunk_size=get_max_chunk_size(sample.device), + output_shape=output_shape, + output_dtype=sample.dtype, ) self.temporal_cache_state[threading.get_ident()] = run_up_state @@ -667,23 +672,57 @@ return output_buffer - def forward_start(self, *args, **kwargs): - #output_buffer == None implies initial exhaustion, so this should park it in a - #resumable state on the bottom of the run_up stack. - raise NotImplementedError("Decoder.forward_start is not implemented yet") + def forward_start( + self, + sample: torch.FloatTensor, + timestep: Optional[torch.Tensor] = None, + ): + try: + return self.forward_orig(sample, timestep=timestep, output_buffer=None) + except Exception: + clear_temporal_cache_state(self) + raise - #completely new function (maybe) - def forward_resume(self, *args, **kwargs): - #your code here - #inspect the new state. Pop any complete frames first. - #Then execute run_up on the highest index frame. You will need to be the list iterator - #for chunked run_up stashes and potentially restash incomplete lists. - #come down the indices as you stash highest to lowest and just like run_up you need to - #return if output_buffer is exhausted with the extra run_up calls potentially leaving - #behind new state for the next time this is called. - #if we have truly finished, do the same logic as the finally below to clean up. - #we will design the protocol for signalling end to the caller later. 
- raise NotImplementedError("Decoder.forward_resume is not implemented yet") + def forward_resume(self, output_t: int): + tid = threading.get_ident() + run_up_state = self.temporal_cache_state.get(tid, None) + if run_up_state is None: + return None + + output_shape = list(run_up_state.output_shape) + output_shape[2] = output_t + output_buffer = torch.empty( + output_shape, + dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(), + ) + output_offset = [0] + + try: + if run_up_state.output_frames is not None: + output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]] + t = output_slice.shape[2] + output_slice.copy_(run_up_state.output_frames[:, :, :t]) + output_offset[0] += t + run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:] + + pending_samples = run_up_state.pending_samples + run_up_state.pending_samples = [] + while len(pending_samples): + idx, samples, ended = pending_samples.pop(0) + while len(samples): + if output_offset[0] == output_buffer.shape[2]: + pending_samples = [(idx, samples, ended)] + pending_samples + run_up_state.pending_samples.extend(pending_samples) + return output_buffer + sample1 = samples.pop(0) + self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset) + + if run_up_state.output_frames is None and not run_up_state.pending_samples: + clear_temporal_cache_state(self) + return output_buffer[:, :, :output_offset[0]] + except Exception: + clear_temporal_cache_state(self) + raise def forward( self, From 0c70446c9b0b7afc29ff198e2f8ea0a3f455c469 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 14 Apr 2026 11:12:47 +1000 Subject: [PATCH 16/17] nodes_video: save stream: Make audio optional --- comfy_extras/nodes_video.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index f04653d28..92fce02cf 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -278,11 +278,11 @@ class SavePassthroughVideoStream(io.ComfyNode): has_intermediate_output=True, inputs=[ io.ImageStream.Input("stream", tooltip="The image stream to save."), - io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. 

From 0c70446c9b0b7afc29ff198e2f8ea0a3f455c469 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 14 Apr 2026 11:12:47 +1000
Subject: [PATCH 16/17] nodes_video: save stream: Make audio optional

Move the optional audio input after the required inputs so that the
execute() signature can give it a default of None.

---
 comfy_extras/nodes_video.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index f04653d28..92fce02cf 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -278,11 +278,11 @@ class SavePassthroughVideoStream(io.ComfyNode):
             has_intermediate_output=True,
             inputs=[
                 io.ImageStream.Input("stream", tooltip="The image stream to save."),
-                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
                 io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
                 io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
                 io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
                 io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
+                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
             ],
             outputs=[
                 io.ImageStream.Output(display_name="passthrough"),
@@ -294,11 +294,11 @@ class SavePassthroughVideoStream(io.ComfyNode):
     def execute(
         cls,
         stream: Input.ImageStream,
-        audio: Optional[Input.Audio],
         fps: float,
         filename_prefix,
         format: str,
         codec,
+        audio: Optional[Input.Audio] = None,
     ) -> io.NodeOutput:
         return io.NodeOutput(_build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec))
 
@@ -316,12 +316,12 @@ class SaveVideoStream(io.ComfyNode):
             is_output_node=True,
             inputs=[
                 io.ImageStream.Input("stream", tooltip="The image stream to save."),
-                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
                 io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
                 io.Int.Input("chunk_size", default=8, min=1, max=4096),
                 io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
                 io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
                 io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
+                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
             ],
             outputs=[],
             hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
@@ -331,12 +331,12 @@ class SaveVideoStream(io.ComfyNode):
     def execute(
         cls,
         stream: Input.ImageStream,
-        audio: Optional[Input.Audio],
         fps: float,
         chunk_size: int,
         filename_prefix,
         format: str,
         codec,
+        audio: Optional[Input.Audio] = None,
    ) -> io.NodeOutput:
         saved_stream = _build_saved_stream(
             cls.hidden,

From 1c2d37944c61ef8d421f8a782f3edb80c392dc9f Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 14 Apr 2026 11:15:47 +1000
Subject: [PATCH 17/17] nodes_image_stream: implement VAE decoder node

Expose the decoder start/resume protocol as VideoVAE.decode_start() and
decode_chunk(), add VAE.decode_output_shape() and decode_stream_start()
wrappers, and add a VAEDecodeStream node that decodes a video latent
into an IMAGE_STREAM one chunk at a time.

---
 .../vae/causal_video_autoencoder.py |  9 +++
 comfy/sd.py                         | 11 +++
 comfy_extras/nodes_image_stream.py  | 81 +++++++++++++++++++
 3 files changed, 101 insertions(+)

diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index c0211addd..91606ffa6 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -1388,6 +1388,15 @@ class VideoVAE(nn.Module):
     def decode_output_shape(self, input_shape):
         return self.decoder.decode_output_shape(input_shape)
 
+    def decode_start(self, x):
+        clear_temporal_cache_state(self.decoder)  # Drop any stale parked state from a previous run.
+        if self.timestep_conditioning: #TODO: seed
+            x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
+        return self.decoder.forward_start(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
+
+    def decode_chunk(self, output_t: int):
+        return self.decoder.forward_resume(output_t)  # Up to output_t frames; None once the stream is finished.
+
     def decode(self, x, output_buffer=None):
         if self.timestep_conditioning: #TODO: seed
             x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
diff --git a/comfy/sd.py b/comfy/sd.py
index e573804a5..9e442772e 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -995,6 +995,17 @@ class VAE:
             pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
 
+    def decode_output_shape(self, samples_shape):
+        self.throw_exception_if_invalid()
+        if hasattr(self.first_stage_model, "decode_output_shape"):
+            return self.first_stage_model.decode_output_shape(samples_shape)
+        raise RuntimeError("This VAE does not expose decode output shape information.")
+
+    def decode_stream_start(self, samples_in):
+        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+        self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))
+
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
diff --git a/comfy_extras/nodes_image_stream.py b/comfy_extras/nodes_image_stream.py
index fd2ac6915..f01075887 100644
--- a/comfy_extras/nodes_image_stream.py
+++ b/comfy_extras/nodes_image_stream.py
@@ -123,6 +123,60 @@ class PreviewingImageStream(Input.ImageStream):
         return chunk
 
+
+class VAEDecodedImageStream(Input.ImageStream):
+    def __init__(self, vae, latent: torch.Tensor):
+        super().__init__()
+        self._vae = vae
+        self._latent = latent
+        vae.throw_exception_if_invalid()
+        if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
+            raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
+        if latent.ndim != 5:
+            raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
+        if latent.shape[0] != 1:
+            raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
+        output_shape = vae.decode_output_shape(latent.shape)
+        self._channels = int(output_shape[1])
+        self._width = int(output_shape[4])
+        self._height = int(output_shape[3])
+        self._total_frames = int(output_shape[0] * output_shape[2])  # batch * frames
+        self._frames_emitted = 0
+
+    def _update_progress(self, value: int) -> None:
+        current = get_executing_context()
+        if current is None:
+            return
+        get_progress_state().update_progress(
+            current.node_id,
+            value=float(value),
+            max_value=float(max(self._total_frames, 1)),
+        )
+
+    def get_dimensions(self) -> tuple[int, int]:
+        return self._width, self._height
+
+    def do_reset(self) -> None:
+        self._frames_emitted = 0
+        self._update_progress(0)
+        self._vae.decode_stream_start(self._latent)
+
+    def do_pull(self, max_frames: int) -> Input.Image:
+        chunk = self._vae.first_stage_model.decode_chunk(max_frames)
+        if chunk is None:  # Decoder already finished: an empty batch signals EOF.
+            return torch.empty(
+                (0, self._height, self._width, self._channels),
+                device=self._vae.output_device,
+                dtype=self._vae.vae_output_dtype(),
+            )
+
+        chunk = chunk.to(device=self._vae.output_device, dtype=self._vae.vae_output_dtype())
+        chunk = self._vae.process_output(chunk).movedim(1, -1)
+        chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))  # Fold batch and time into one frame axis.
+        self._frames_emitted += int(chunk.shape[0])
+        self._update_progress(self._frames_emitted)
+        return chunk
+
+
 class ImageBatchToStream(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -178,6 +232,32 @@ class ImageStreamToBatch(io.ComfyNode):
         return io.NodeOutput(torch.cat(chunks, dim=0))
+
+class VAEDecodeStream(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="VAEDecodeStream",
+            display_name="VAE Decode Stream",
+            category="image/stream",
+            search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
+            description="Decodes a video latent into an IMAGE_STREAM chunk by chunk, so frames can be consumed without materializing the whole video at once.",
+            inputs=[
+                io.Latent.Input("samples", tooltip="The LTX latent to decode."),
+                io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
+            ],
+            outputs=[
+                io.ImageStream.Output(display_name="stream"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
+        latent = samples["samples"]
+        if latent.is_nested:
+            latent = latent.unbind()[0]
+        return io.NodeOutput(VAEDecodedImageStream(vae, latent))
+
+
 class PreviewImageStream(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -230,6 +310,7 @@ class ImageStreamExtension(ComfyExtension):
         return [
             ImageBatchToStream,
             ImageStreamToBatch,
+            VAEDecodeStream,
             PreviewImageStream,
             StreamSink,
         ]
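Usage sketch for the new node's stream (illustrative only, not part of
the patch): draining the IMAGE_STREAM through the public pull API.
`stream` is assumed to be the output of VAEDecodeStream (a
VAEDecodedImageStream, though any Input.ImageStream behaves the same)
and the chunk size of 8 is arbitrary:

    import torch

    def drain(stream, max_frames=8):
        stream.reset()  # kicks off the chunked VAE decode from frame 0
        chunks = []
        while True:
            chunk = stream.pull(max_frames)  # [n, H, W, C] batch with n <= max_frames
            chunks.append(chunk)
            if chunk.shape[0] < max_frames:  # a short batch signals EOF
                return torch.cat(chunks, dim=0)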