mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
node_image_stream: add
Add some nodes getting into and out of stream mode.
This commit is contained in:
parent
7dc366adc7
commit
7b1d4bcdf6
174
comfy_extras/nodes_image_stream.py
Normal file
174
comfy_extras/nodes_image_stream.py
Normal file
@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, Input, io, ui
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class FrameProgressTracker:
|
||||
def __init__(self):
|
||||
self._last_reported=None
|
||||
|
||||
def emit(self, frames_processed):
|
||||
if frames_processed == self._last_reported:
|
||||
return
|
||||
|
||||
current = get_executing_context()
|
||||
server = getattr(PromptServer, "instance", None)
|
||||
if current is None or server is None or server.client_id is None:
|
||||
return
|
||||
|
||||
server.send_progress_text(
|
||||
f"processed {frames_processed} frames",
|
||||
current.node_id,
|
||||
server.client_id,
|
||||
)
|
||||
self._last_reported = frames_processed
|
||||
|
||||
|
||||
def drain_image_stream(stream, chunk_size, progress):
|
||||
stream.reset()
|
||||
frames_processed = 0
|
||||
|
||||
while True:
|
||||
progress.emit(frames_processed)
|
||||
chunk = stream.pull(chunk_size)
|
||||
frames_processed += int(chunk.shape[0])
|
||||
if chunk.shape[0] < chunk_size:
|
||||
progress.emit(frames_processed)
|
||||
return frames_processed
|
||||
|
||||
class TensorImageStream(Input.ImageStream):
|
||||
"""Simple IMAGE_STREAM backed by a materialized IMAGE batch tensor."""
|
||||
|
||||
def __init__(self, images: Input.Image):
|
||||
super().__init__()
|
||||
self._images = images
|
||||
self._index = 0
|
||||
self._total_frames = int(images.shape[0])
|
||||
self._progress_started = False
|
||||
|
||||
def _update_progress(self, value: int) -> None:
|
||||
current = get_executing_context()
|
||||
if current is None:
|
||||
return
|
||||
get_progress_state().update_progress(
|
||||
current.node_id,
|
||||
value=float(value),
|
||||
max_value=float(max(self._total_frames, 1)),
|
||||
)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
return self._images.shape[2], self._images.shape[1]
|
||||
|
||||
def do_reset(self) -> None:
|
||||
self._index = 0
|
||||
self._progress_started = False
|
||||
|
||||
def do_pull(self, max_frames: int) -> Input.Image:
|
||||
if not self._progress_started:
|
||||
self._update_progress(0)
|
||||
self._progress_started = True
|
||||
|
||||
start = self._index
|
||||
end = min(start + max_frames, self._images.shape[0])
|
||||
self._index = end
|
||||
chunk = self._images[start:end].clone()
|
||||
self._update_progress(end)
|
||||
return chunk
|
||||
|
||||
|
||||
class ImageBatchToStream(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageBatchToStream",
|
||||
display_name="Image Batch To Stream",
|
||||
category="image/stream",
|
||||
search_aliases=["image to stream", "batch to stream", "frames to stream"],
|
||||
description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.",
|
||||
inputs=[
|
||||
io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format."),
|
||||
],
|
||||
outputs=[
|
||||
io.ImageStream.Output(display_name="stream"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: Input.Image) -> io.NodeOutput:
|
||||
return io.NodeOutput(TensorImageStream(image))
|
||||
|
||||
|
||||
class ImageStreamToBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageStreamToBatch",
|
||||
display_name="Image Stream To Batch",
|
||||
category="image/stream",
|
||||
search_aliases=["stream to image", "stream to batch", "collect stream"],
|
||||
description="Materializes an IMAGE_STREAM back into a batched IMAGE tensor.",
|
||||
inputs=[
|
||||
io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM."),
|
||||
io.Int.Input("batch_size", default=4096, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput:
|
||||
chunks: list[Input.Image] = []
|
||||
stream.reset()
|
||||
|
||||
while True:
|
||||
chunk = stream.pull(batch_size)
|
||||
|
||||
chunks.append(chunk)
|
||||
if chunk.shape[0] < batch_size:
|
||||
break
|
||||
|
||||
return io.NodeOutput(torch.cat(chunks, dim=0))
|
||||
|
||||
|
||||
class StreamSink(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StreamSink",
|
||||
search_aliases=["consume stream", "drain stream", "image stream sink"],
|
||||
display_name="Stream Sink",
|
||||
category="image/stream",
|
||||
description="Consumes an IMAGE_STREAM by pulling it to EOF.",
|
||||
inputs=[
|
||||
io.ImageStream.Input("stream", tooltip="The image stream to consume."),
|
||||
io.Int.Input("chunk_size", default=8, min=1, max=4096),
|
||||
],
|
||||
outputs=[],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput:
|
||||
drain_image_stream(stream, chunk_size, progress=FrameProgressTracker())
|
||||
return io.NodeOutput()
|
||||
|
||||
|
||||
class ImageStreamExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
ImageBatchToStream,
|
||||
ImageStreamToBatch,
|
||||
StreamSink,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ImageStreamExtension:
|
||||
return ImageStreamExtension()
|
||||
Loading…
Reference in New Issue
Block a user