# ComfyUI/comfy_extras/nodes_image_stream.py
from __future__ import annotations
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, io, ui
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from server import PromptServer
class FrameProgressTracker:
    """Sends "processed N frames" status text to the frontend, skipping repeats."""

    def __init__(self):
        # Last count actually sent; None until the first successful report.
        self._last_reported = None

    def emit(self, frames_processed):
        """Report *frames_processed* to the client unless it was already reported."""
        if frames_processed == self._last_reported:
            return
        ctx = get_executing_context()
        srv = getattr(PromptServer, "instance", None)
        # Quietly do nothing when there is no executing node or connected client.
        if ctx is None or srv is None or srv.client_id is None:
            return
        srv.send_progress_text(
            f"processed {frames_processed} frames",
            ctx.node_id,
            srv.client_id,
        )
        self._last_reported = frames_processed
def drain_image_stream(stream, chunk_size, progress):
    """Pull *stream* to EOF in chunks of *chunk_size*, reporting via *progress*.

    A chunk shorter than *chunk_size* marks end-of-stream. Progress is emitted
    before each pull and once more with the final total.

    Returns the total number of frames pulled.
    """
    stream.reset()
    total = 0
    while True:
        progress.emit(total)
        pulled = int(stream.pull(chunk_size).shape[0])
        total += pulled
        if pulled < chunk_size:
            progress.emit(total)
            return total
class TensorImageStream(Input.ImageStream):
    """Simple IMAGE_STREAM backed by a materialized IMAGE batch tensor."""

    def __init__(self, images: Input.Image):
        super().__init__()
        self._images = images
        self._index = 0
        self._total_frames = int(images.shape[0])
        self._progress_started = False

    def _update_progress(self, value: int) -> None:
        # Progress updates only make sense while a node is executing.
        ctx = get_executing_context()
        if ctx is None:
            return
        get_progress_state().update_progress(
            ctx.node_id,
            value=float(value),
            max_value=float(max(self._total_frames, 1)),
        )

    def get_dimensions(self) -> tuple[int, int]:
        """Return (width, height); the backing tensor is BHWC."""
        return self._images.shape[2], self._images.shape[1]

    def do_reset(self) -> None:
        """Rewind to the first frame and re-arm the 0% progress report."""
        self._index = 0
        self._progress_started = False

    def do_pull(self, max_frames: int) -> Input.Image:
        """Return a cloned slice of up to *max_frames* frames."""
        if not self._progress_started:
            self._update_progress(0)
            self._progress_started = True
        begin = self._index
        stop = min(begin + max_frames, self._images.shape[0])
        self._index = stop
        piece = self._images[begin:stop].clone()
        self._update_progress(stop)
        return piece
class PreviewingImageStream(Input.ImageStream):
    """Transparent stream wrapper that previews the last frame of each pulled chunk."""

    def __init__(self, stream: Input.ImageStream):
        super().__init__()
        self._stream = stream

    def _emit_preview(self, chunk: Input.Image) -> None:
        """Best-effort: send a one-frame preview to the client, or do nothing."""
        # An empty chunk (EOF) has no frame to show.
        if int(chunk.shape[0]) == 0:
            return
        ctx = get_executing_context()
        if ctx is None:
            return
        srv = getattr(PromptServer, "instance", None)
        if srv is None or srv.client_id is None:
            return
        preview_output = ui.PreviewImage(chunk[-1:]).as_dict()
        payload = {
            "node": ctx.node_id,
            "display_node": ctx.node_id,
            "output": preview_output,
            "prompt_id": ctx.prompt_id,
        }
        srv.send_sync("executed", payload, srv.client_id)

    def get_dimensions(self) -> tuple[int, int]:
        return self._stream.get_dimensions()

    def do_reset(self) -> None:
        self._stream.reset()

    def do_pull(self, max_frames: int) -> Input.Image:
        chunk = self._stream.pull(max_frames)
        self._emit_preview(chunk)
        return chunk
class VAEDecodedImageStream(Input.ImageStream):
    """IMAGE_STREAM that decodes a video latent lazily via a chunk-capable VAE.

    Frames are produced on demand by the VAE's chunked decoder instead of
    decoding the whole latent up front.
    """

    def __init__(self, vae, latent: Input.Latent):
        super().__init__()
        self._vae = vae
        self._latent = latent
        vae.throw_exception_if_invalid()
        # Chunked decode is opt-in on the VAE model; fail fast when unsupported.
        if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
            raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
        if latent.ndim != 5:
            raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
        if latent.shape[0] != 1:
            raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
        # Precompute decoded output geometry so dimensions and total-frame
        # progress are known before any chunk is actually decoded.
        output_shape = vae.decode_output_shape(latent.shape)
        self._channels = int(output_shape[1])
        self._width = int(output_shape[4])
        self._height = int(output_shape[3])
        self._total_frames = int(output_shape[0] * output_shape[2])
        self._frames_emitted = 0

    def _update_progress(self, value: int) -> None:
        """Push frame-count progress for the currently executing node, if any."""
        current = get_executing_context()
        if current is None:
            return
        get_progress_state().update_progress(
            current.node_id,
            value=float(value),
            max_value=float(max(self._total_frames, 1)),
        )

    def get_dimensions(self) -> tuple[int, int]:
        """Return (width, height) of decoded frames."""
        return self._width, self._height

    def do_reset(self) -> None:
        """Restart decoding from the beginning of the latent."""
        self._frames_emitted = 0
        self._update_progress(0)
        self._vae.decode_stream_start(self._latent)

    def do_pull(self, max_frames: int) -> Input.Image:
        """Decode and return up to *max_frames* frames; empty tensor at EOF."""
        chunk = self._vae.first_stage_model.decode_chunk(max_frames)
        if chunk is None:
            # Decoder exhausted: return an empty frames-first tensor on the
            # VAE's output device/dtype so downstream concatenation still works.
            return torch.empty(
                (0, self._height, self._width, self._channels),
                device=self._vae.output_device,
                dtype=self._vae.vae_output_dtype(),
            )
        chunk = chunk.to(device=self._vae.output_device, dtype=self._vae.vae_output_dtype())
        # Move the channel axis (dim 1) to last, then flatten all leading dims
        # into a single frame axis so the result is [frames, H, W, C].
        chunk = self._vae.process_output(chunk).movedim(1, -1)
        chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))
        self._frames_emitted += int(chunk.shape[0])
        self._update_progress(self._frames_emitted)
        return chunk
class ImageBatchToStream(io.ComfyNode):
    """Node that exposes a batched IMAGE tensor as a pull-based IMAGE_STREAM."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageBatchToStream",
            display_name="Image Batch To Stream",
            description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.",
            category="image/stream",
            search_aliases=["image to stream", "batch to stream", "frames to stream"],
            inputs=[io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format.")],
            outputs=[io.ImageStream.Output(display_name="stream")],
        )

    @classmethod
    def execute(cls, image: Input.Image) -> io.NodeOutput:
        return io.NodeOutput(TensorImageStream(image))
class ImageStreamToBatch(io.ComfyNode):
    """Node that materializes an IMAGE_STREAM into a single IMAGE batch tensor."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageStreamToBatch",
            display_name="Image Stream To Batch",
            description="Materializes an IMAGE_STREAM back into a batched IMAGE tensor.",
            category="image/stream",
            search_aliases=["stream to image", "stream to batch", "collect stream"],
            inputs=[
                io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM."),
                io.Int.Input("batch_size", default=4096, min=1, max=4096),
            ],
            outputs=[io.Image.Output(display_name="image")],
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput:
        stream.reset()
        collected: list[Input.Image] = []
        while True:
            piece = stream.pull(batch_size)
            collected.append(piece)
            # A short chunk signals end-of-stream.
            if piece.shape[0] < batch_size:
                return io.NodeOutput(torch.cat(collected, dim=0))
class VAEDecodeStream(io.ComfyNode):
    """Node producing a lazily VAE-decoded IMAGE_STREAM from a video latent."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VAEDecodeStream",
            display_name="VAE Decode Stream",
            description="Decodes a latent into an IMAGE_STREAM.",
            category="image/stream",
            search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
            inputs=[
                io.Latent.Input("samples", tooltip="The LTX latent to decode."),
                io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
            ],
            outputs=[io.ImageStream.Output(display_name="stream")],
        )

    @classmethod
    def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
        latent = samples["samples"]
        # Nested latent batches: stream only the first entry.
        if latent.is_nested:
            latent = latent.unbind()[0]
        return io.NodeOutput(VAEDecodedImageStream(vae, latent))
class PreviewImageStream(io.ComfyNode):
    """Pass-through node that previews frames as the stream is pulled downstream."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="PreviewImageStream",
            display_name="Preview Image Stream",
            description="Passes an IMAGE_STREAM through while previewing the last frame from each pulled chunk.",
            category="image/stream",
            search_aliases=["stream preview", "preview frames", "preview image stream"],
            has_intermediate_output=True,
            inputs=[io.ImageStream.Input("stream", tooltip="The image stream to preview inline.")],
            outputs=[io.ImageStream.Output(display_name="passthrough")],
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream) -> io.NodeOutput:
        return io.NodeOutput(PreviewingImageStream(stream))
class StreamSink(io.ComfyNode):
    """Output node that forces an IMAGE_STREAM to run to completion."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="StreamSink",
            display_name="Stream Sink",
            description="Consumes an IMAGE_STREAM by pulling it to EOF.",
            category="image/stream",
            search_aliases=["consume stream", "drain stream", "image stream sink"],
            inputs=[
                io.ImageStream.Input("stream", tooltip="The image stream to consume."),
                io.Int.Input("chunk_size", default=8, min=1, max=4096),
            ],
            outputs=[],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput:
        # Pulling to EOF triggers any upstream lazy work (decoding, previews).
        drain_image_stream(stream, chunk_size, progress=FrameProgressTracker())
        return io.NodeOutput()
class ImageStreamExtension(ComfyExtension):
    """Registers the image-stream node pack with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        nodes: list[type[io.ComfyNode]] = [
            ImageBatchToStream,
            ImageStreamToBatch,
            VAEDecodeStream,
            PreviewImageStream,
            StreamSink,
        ]
        return nodes
async def comfy_entrypoint() -> ImageStreamExtension:
    """Entry point discovered by ComfyUI to load this extension's nodes."""
    return ImageStreamExtension()