from __future__ import annotations

import torch
from typing_extensions import override

from comfy_api.latest import ComfyExtension, Input, io, ui
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from server import PromptServer


class FrameProgressTracker:
    """Sends "processed N frames" progress text, skipping repeated counts.

    A count is remembered only once it has actually been sent, so a count
    that could not be delivered (no executing context, no server instance,
    or no client id) is attempted again on the next call.
    """

    def __init__(self):
        # Frame count of the most recently sent message; None until one is sent.
        self._last_reported = None

    def emit(self, frames_processed):
        """Send the progress text for ``frames_processed`` unless it repeats.

        Args:
            frames_processed: Number of frames consumed so far.
        """
        if frames_processed == self._last_reported:
            return

        context = get_executing_context()
        prompt_server = getattr(PromptServer, "instance", None)
        # Nothing can be delivered without a node, a server and a client.
        if context is None:
            return
        if prompt_server is None:
            return
        if prompt_server.client_id is None:
            return

        message = f"processed {frames_processed} frames"
        prompt_server.send_progress_text(
            message,
            context.node_id,
            prompt_server.client_id,
        )
        self._last_reported = frames_processed


def drain_image_stream(stream, chunk_size, progress):
    """Reset ``stream`` and pull it in chunks until a short chunk arrives.

    Args:
        stream: Object exposing ``reset()`` and ``pull(n)``; each pulled chunk
            must expose ``shape[0]`` as its frame count.
        chunk_size: Number of frames requested per pull.
        progress: Object exposing ``emit(count)``; it is called once before
            the first pull and once after every pull.

    Returns:
        The total number of frames pulled.
    """
    stream.reset()
    total = 0
    progress.emit(total)
    while True:
        batch = stream.pull(chunk_size)
        pulled = batch.shape[0]
        total += int(pulled)
        progress.emit(total)
        # A chunk shorter than requested (possibly empty) marks end of stream.
        if pulled < chunk_size:
            return total


class TensorImageStream(Input.ImageStream):
    """IMAGE_STREAM that serves frames from an already materialized IMAGE batch.

    Frames are handed out in order as cloned slices of the wrapped tensor,
    and node progress is reported against the batch's total frame count.
    ``do_reset`` / ``do_pull`` are presumably the hooks driven by the base
    class's ``reset()`` / ``pull()`` -- confirm against ``Input.ImageStream``.
    """

    def __init__(self, images: Input.Image):
        super().__init__()
        self._images = images
        self._index = 0  # position of the next frame to hand out
        self._total_frames = int(images.shape[0])
        self._progress_started = False  # True once the initial 0 was reported

    def _update_progress(self, value: int) -> None:
        """Report ``value`` out of the total frame count for the running node."""
        context = get_executing_context()
        if context is None:
            return
        # max_value is floored at 1 so an empty batch never reports a zero maximum.
        ceiling = float(max(self._total_frames, 1))
        get_progress_state().update_progress(
            context.node_id,
            value=float(value),
            max_value=ceiling,
        )

    def get_dimensions(self) -> tuple[int, int]:
        """Return ``(shape[2], shape[1])`` of the batch: width, height for BHWC."""
        shape = self._images.shape
        return shape[2], shape[1]

    def do_reset(self) -> None:
        """Rewind to the first frame and re-arm the initial progress report."""
        self._index = 0
        self._progress_started = False

    def do_pull(self, max_frames: int) -> Input.Image:
        """Return a clone of the next ``max_frames`` frames (fewer near the end).

        Once every frame has been handed out, the returned chunk is empty.
        """
        if not self._progress_started:
            self._update_progress(0)
            self._progress_started = True

        first = self._index
        last = min(first + max_frames, self._images.shape[0])
        self._index = last
        # clone() returns a copy rather than a view of the source tensor.
        frames = self._images[first:last].clone()
        self._update_progress(last)
        return frames


class ImageBatchToStream(io.ComfyNode):
    """Node that wraps a batched IMAGE tensor in a ``TensorImageStream``."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageBatchToStream",
            display_name="Image Batch To Stream",
            category="image/stream",
            search_aliases=["image to stream", "batch to stream", "frames to stream"],
            description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.",
            inputs=[
                io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format."),
            ],
            outputs=[
                io.ImageStream.Output(display_name="stream"),
            ],
        )

    @classmethod
    def execute(cls, image: Input.Image) -> io.NodeOutput:
        wrapped = TensorImageStream(image)
        return io.NodeOutput(wrapped)


class ImageStreamToBatch(io.ComfyNode):
    """Node that pulls an IMAGE_STREAM to its end and concatenates the chunks."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageStreamToBatch",
            display_name="Image Stream To Batch",
            category="image/stream",
            search_aliases=["stream to image", "stream to batch", "collect stream"],
            description="Materializes an IMAGE_STREAM back into a batched IMAGE tensor.",
            inputs=[
                io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM."),
                io.Int.Input("batch_size", default=4096, min=1, max=4096),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
            ],
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput:
        pieces: list[Input.Image] = []
        stream.reset()
        reached_end = False
        while not reached_end:
            piece = stream.pull(batch_size)
            # The final (short, possibly empty) chunk is kept as well.
            pieces.append(piece)
            reached_end = piece.shape[0] < batch_size
        return io.NodeOutput(torch.cat(pieces, dim=0))


class StreamSink(io.ComfyNode):
    """Output node that consumes an IMAGE_STREAM, reporting frame counts."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="StreamSink",
            display_name="Stream Sink",
            category="image/stream",
            search_aliases=["consume stream", "drain stream", "image stream sink"],
            description="Consumes an IMAGE_STREAM by pulling it to EOF.",
            inputs=[
                io.ImageStream.Input("stream", tooltip="The image stream to consume."),
                io.Int.Input("chunk_size", default=8, min=1, max=4096),
            ],
            outputs=[],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput:
        # A fresh tracker per execution, so deduplication starts clean each run.
        tracker = FrameProgressTracker()
        drain_image_stream(stream, chunk_size, progress=tracker)
        return io.NodeOutput()


class ImageStreamExtension(ComfyExtension):
    """Extension exposing the image-stream nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        nodes = [
            ImageBatchToStream,
            ImageStreamToBatch,
            StreamSink,
        ]
        return nodes


async def comfy_entrypoint() -> ImageStreamExtension:
    """Create the extension instance for this module."""
    return ImageStreamExtension()