ComfyUI/comfy_extras/nodes_image_stream.py
Rattus 7b1d4bcdf6 node_image_stream: add
Add some nodes for getting into and out of stream mode.
2026-04-14 11:18:07 +10:00

175 lines
5.4 KiB
Python

from __future__ import annotations
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, io, ui
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from server import PromptServer
class FrameProgressTracker:
    """Report a running frame count as progress text for the executing node.

    A count equal to the last one successfully sent is not sent again.
    """

    def __init__(self):
        # Last count actually delivered to the server; None until the first send.
        self._last_reported = None

    def emit(self, frames_processed: int) -> None:
        """Send ``"processed N frames"`` for the current node unless N is a repeat.

        Nothing is sent (and the remembered count is left unchanged) when
        there is no executing context, no ``PromptServer.instance``, or the
        server has no ``client_id``.
        """
        if frames_processed != self._last_reported:
            context = get_executing_context()
            prompt_server = getattr(PromptServer, "instance", None)
            # Short-circuit order matters: client_id is only read once the server exists.
            if context is not None and prompt_server is not None and prompt_server.client_id is not None:
                message = f"processed {frames_processed} frames"
                prompt_server.send_progress_text(message, context.node_id, prompt_server.client_id)
                self._last_reported = frames_processed
def drain_image_stream(stream, chunk_size: int, progress) -> int:
    """Pull ``stream`` from its start until exhausted, reporting progress as it goes.

    The stream is reset first and then pulled ``chunk_size`` frames at a time.
    A chunk shorter than ``chunk_size`` (including an empty one) ends the drain.

    Args:
        stream: object exposing ``reset()`` and ``pull(n)``; ``pull`` returns a
            chunk whose ``shape[0]`` is the number of frames it holds.
        chunk_size: maximum number of frames requested per pull; must be >= 1.
        progress: object with an ``emit(frames_processed)`` method. It is called
            before every pull and once more after the final pull.

    Returns:
        Total number of frames pulled from the stream.

    Raises:
        ValueError: if ``chunk_size`` is less than 1. With 0, the short-chunk
            end test can never pass, so the loop would never terminate.
            Negative sizes are not meaningful pull requests.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    stream.reset()
    frames_processed = 0
    while True:
        progress.emit(frames_processed)
        chunk = stream.pull(chunk_size)
        pulled = int(chunk.shape[0])
        frames_processed += pulled
        if pulled < chunk_size:
            # Short (or empty) chunk: the stream is exhausted.
            progress.emit(frames_processed)
            return frames_processed
class TensorImageStream(Input.ImageStream):
    """IMAGE_STREAM that serves frames from an IMAGE batch tensor already in memory.

    Each pull returns the next run of up to ``max_frames`` frames as a cloned
    slice of the source batch. It also reports the number of frames served so
    far via ``get_progress_state().update_progress`` for the executing node.
    ``do_reset``/``do_pull`` are presumably invoked by the base class's
    ``reset()``/``pull()`` — confirm against ``Input.ImageStream``.
    """

    def __init__(self, images: Input.Image):
        super().__init__()
        self._images = images  # source batch; frames are taken along axis 0
        self._index = 0  # position of the next frame to hand out
        self._total_frames = int(images.shape[0])  # used as the progress maximum
        self._progress_started = False  # True once this pass has reported 0

    def _update_progress(self, value: int) -> None:
        """Report ``value`` against the total frame count, if a node is executing."""
        context = get_executing_context()
        if context is not None:
            # Floor the maximum at 1 so an empty batch still gets a non-zero max_value.
            upper = float(max(self._total_frames, 1))
            get_progress_state().update_progress(
                context.node_id,
                value=float(value),
                max_value=upper,
            )

    def get_dimensions(self) -> tuple[int, int]:
        """Return ``(width, height)``, i.e. axis 2 then axis 1 of the image batch."""
        shape = self._images.shape
        return shape[2], shape[1]

    def do_reset(self) -> None:
        """Rewind to the first frame and re-arm the initial progress report."""
        self._progress_started = False
        self._index = 0

    def do_pull(self, max_frames: int) -> Input.Image:
        """Return up to ``max_frames`` frames starting at the current position.

        Once every frame has been served, the returned slice is empty.
        """
        if not self._progress_started:
            # Report 0 once per pass before the first chunk is handed out.
            self._update_progress(0)
            self._progress_started = True
        begin = self._index
        stop = min(begin + max_frames, self._images.shape[0])
        self._index = stop
        # Clone so the consumer receives its own copy rather than a view of the source.
        chunk = self._images[begin:stop].clone()
        self._update_progress(stop)
        return chunk
class ImageBatchToStream(io.ComfyNode):
    """Node that exposes an in-memory IMAGE batch as a pull-based IMAGE_STREAM."""

    @classmethod
    def define_schema(cls):
        """Describe one IMAGE input ("image") and one IMAGE_STREAM output ("stream")."""
        image_input = io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format.")
        stream_output = io.ImageStream.Output(display_name="stream")
        return io.Schema(
            node_id="ImageBatchToStream",
            display_name="Image Batch To Stream",
            category="image/stream",
            search_aliases=["image to stream", "batch to stream", "frames to stream"],
            description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.",
            inputs=[image_input],
            outputs=[stream_output],
        )

    @classmethod
    def execute(cls, image: Input.Image) -> io.NodeOutput:
        """Wrap ``image`` in a TensorImageStream; no frames are copied at this point."""
        stream = TensorImageStream(image)
        return io.NodeOutput(stream)
class ImageStreamToBatch(io.ComfyNode):
    """Node that collects every frame of an IMAGE_STREAM into one IMAGE batch."""

    @classmethod
    def define_schema(cls):
        """Describe the stream input, the per-pull size input, and the IMAGE output."""
        stream_input = io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM.")
        # "batch_size" is the number of frames requested per pull, not the output size.
        pull_size_input = io.Int.Input("batch_size", default=4096, min=1, max=4096)
        image_output = io.Image.Output(display_name="image")
        return io.Schema(
            node_id="ImageStreamToBatch",
            display_name="Image Stream To Batch",
            category="image/stream",
            search_aliases=["stream to image", "stream to batch", "collect stream"],
            description="Materializes an IMAGE_STREAM back into a batched IMAGE tensor.",
            inputs=[stream_input, pull_size_input],
            outputs=[image_output],
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput:
        """Reset ``stream`` and pull it ``batch_size`` frames at a time.

        Pulling stops when a chunk arrives with fewer frames than requested.
        Every chunk (including that last one) is concatenated along dim 0.
        """
        collected: list[Input.Image] = []
        stream.reset()
        done = False
        while not done:
            piece = stream.pull(batch_size)
            collected.append(piece)
            # A short (possibly empty) chunk is treated as end of stream.
            done = piece.shape[0] < batch_size
        return io.NodeOutput(torch.cat(collected, dim=0))
class StreamSink(io.ComfyNode):
    """Output node that pulls an IMAGE_STREAM to the end and discards the frames."""

    @classmethod
    def define_schema(cls):
        """Describe the stream input and pull-size input; the node has no outputs."""
        stream_input = io.ImageStream.Input("stream", tooltip="The image stream to consume.")
        chunk_size_input = io.Int.Input("chunk_size", default=8, min=1, max=4096)
        return io.Schema(
            node_id="StreamSink",
            search_aliases=["consume stream", "drain stream", "image stream sink"],
            display_name="Stream Sink",
            category="image/stream",
            description="Consumes an IMAGE_STREAM by pulling it to EOF.",
            inputs=[stream_input, chunk_size_input],
            outputs=[],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput:
        """Drain ``stream`` in ``chunk_size`` pulls, reporting frame counts as progress text."""
        tracker = FrameProgressTracker()
        drain_image_stream(stream, chunk_size, progress=tracker)
        return io.NodeOutput()
class ImageStreamExtension(ComfyExtension):
    """Extension that registers the image-stream conversion nodes of this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this module, in registration order."""
        node_classes: list[type[io.ComfyNode]] = [ImageBatchToStream, ImageStreamToBatch, StreamSink]
        return node_classes
async def comfy_entrypoint() -> ImageStreamExtension:
    """Return a new ImageStreamExtension (presumably the hook the ComfyUI loader calls — confirm)."""
    extension = ImageStreamExtension()
    return extension