rattus 2026-04-23 15:54:59 +01:00 committed by GitHub
commit 2e0438bbd9
14 changed files with 873 additions and 54 deletions

View File

@ -51,6 +51,7 @@ class IO(StrEnum):
BBOX = "BBOX"
SEGS = "SEGS"
VIDEO = "VIDEO"
IMAGE_STREAM = "IMAGE_STREAM"
ANY = "*"
"""Always matches any type, but at a price.

View File

@ -16,6 +16,17 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
class RunUpState:
def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None):
self.timestep_shift_scale = timestep_shift_scale
self.scaled_timestep = scaled_timestep
self.checkpoint_fn = checkpoint_fn
self.max_chunk_size = max_chunk_size
self.output_shape = output_shape
self.output_dtype = output_dtype
self.output_frames = output_frames
self.pending_samples = []
def in_meta_context():
return torch.device("meta") == torch.empty(0).device
@ -26,6 +37,14 @@ def mark_conv3d_ended(module):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)
def clear_temporal_cache_state(module):
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
for _, m in module.named_modules():
if hasattr(m, "temporal_cache_state"):
m.temporal_cache_state.pop(tid, None)
def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
@ -315,13 +334,7 @@ class Encoder(nn.Module):
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
clear_temporal_cache_state(self)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
@ -530,19 +543,20 @@ class Decoder(nn.Module):
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
self.temporal_cache_state = {}
def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
if run_up_state.timestep_shift_scale is not None:
shift, scale = run_up_state.timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
@ -550,38 +564,49 @@ class Decoder(nn.Module):
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
if output_buffer is None:
run_up_state.output_frames = sample
return
output_slice = output_buffer[:, :, output_offset[0]:output_offset[0] + sample.shape[2]]
t = output_slice.shape[2]
output_slice.copy_(sample[:, :, :t])
output_offset[0] += t
if t < sample.shape[2]:
run_up_state.output_frames = sample[:, :, t:]
return
up_block = self.up_blocks[idx]
if ended:
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
sample = run_up_state.checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=run_up_state.scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = run_up_state.checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
num_chunks = (total_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size
if num_chunks == 1:
# when we are not chunking, move our only reference into the list so the callee can free the tensor as soon as it's done
next_sample_ref = [sample]
del sample
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
# Just let this run_up proceed unconditionally; that's OK because either a lower-layer
# chunker or the output-frame stash will do the work anyway, so the result is unchanged.
self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))
for chunk_idx, sample1 in enumerate(samples):
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
while len(samples):
if output_buffer is None or output_offset[0] == output_buffer.shape[2]:
run_up_state.pending_samples.append((idx + 1, samples, ended))
return
self.run_up(idx + 1, [samples.pop(0)], ended and len(samples) == 1, run_up_state, output_buffer, output_offset)
def forward_orig(
self,
@ -591,6 +616,7 @@ class Decoder(nn.Module):
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
output_shape = self.decode_output_shape(sample.shape)
mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal)
@ -630,29 +656,89 @@ class Decoder(nn.Module):
)
timestep_shift_scale = ada_values.unbind(dim=1)
output_offset = [0]
run_up_state = RunUpState(
timestep_shift_scale=timestep_shift_scale,
scaled_timestep=scaled_timestep,
checkpoint_fn=checkpoint_fn,
max_chunk_size=get_max_chunk_size(sample.device),
output_shape=output_shape,
output_dtype=sample.dtype,
)
self.temporal_cache_state[threading.get_ident()] = run_up_state
self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset)
return output_buffer
def forward_start(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
):
try:
return self.forward_orig(sample, timestep=timestep, output_buffer=None)
except Exception:
clear_temporal_cache_state(self)
raise
def forward_resume(self, output_t: int):
tid = threading.get_ident()
run_up_state = self.temporal_cache_state.get(tid, None)
if run_up_state is None:
return None
output_shape = list(run_up_state.output_shape)
output_shape[2] = output_t
output_buffer = torch.empty(
output_shape,
dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
try:
if run_up_state.output_frames is not None:
output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
t = output_slice.shape[2]
output_slice.copy_(run_up_state.output_frames[:, :, :t])
output_offset[0] += t
run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]
pending_samples = run_up_state.pending_samples
run_up_state.pending_samples = []
while len(pending_samples):
idx, samples, ended = pending_samples.pop(0)
while len(samples):
if output_offset[0] == output_buffer.shape[2]:
pending_samples = [(idx, samples, ended)] + pending_samples
run_up_state.pending_samples.extend(pending_samples)
return output_buffer
sample1 = samples.pop(0)
self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)
if run_up_state.output_frames is None and not run_up_state.pending_samples:
clear_temporal_cache_state(self)
return output_buffer[:, :, :output_offset[0]]
except Exception:
clear_temporal_cache_state(self)
raise
def forward(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
):
if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
max_chunk_size = get_max_chunk_size(sample.device)
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
return output_buffer
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
finally:
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
clear_temporal_cache_state(self)
class UNetMidBlock3D(nn.Module):
@ -1302,6 +1388,15 @@ class VideoVAE(nn.Module):
def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)
def decode_start(self, x):
clear_temporal_cache_state(self.decoder)
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder.forward_start(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
def decode_chunk(self, output_t: int):
return self.decoder.forward_resume(output_t)
def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
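
A minimal sketch of how a caller might drive the chunked decode API added above (this is an assumed usage pattern, not code from the commit); it assumes vae is a loaded VideoVAE and x is a latent shaped [batch, channels, frames, height, width]:

import torch

def decode_in_chunks(vae, x, chunk_frames=8):
    # decode_start primes the decoder and stashes a RunUpState keyed on this thread
    vae.decode_start(x)
    frames = []
    while True:
        # decode_chunk returns up to chunk_frames output frames,
        # or None once no resume state remains
        chunk = vae.decode_chunk(chunk_frames)
        if chunk is None:
            break
        frames.append(chunk)
        if chunk.shape[2] < chunk_frames:  # a short return also signals exhaustion
            break
    return torch.cat(frames, dim=2) if frames else None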

View File

@ -1014,6 +1014,17 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_output_shape(self, samples_shape):
self.throw_exception_if_invalid()
if hasattr(self.first_stage_model, "decode_output_shape"):
return self.first_stage_model.decode_output_shape(samples_shape)
raise RuntimeError("This VAE does not expose decode output shape information.")
def decode_stream_start(self, samples_in):
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile

View File

@ -2,6 +2,7 @@
from comfy_api.latest._input import (
ImageInput,
AudioInput,
ImageStreamInput,
MaskInput,
LatentInput,
VideoInput,
@ -14,6 +15,7 @@ from comfy_api.latest._input import (
__all__ = [
"ImageInput",
"AudioInput",
"ImageStreamInput",
"MaskInput",
"LatentInput",
"VideoInput",

View File

@ -0,0 +1,6 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._input.image_stream_types import ImageStreamInput
__all__ = [
"ImageStreamInput",
]

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input import ImageInput, AudioInput, ImageStreamInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from . import _io_public as io
@ -131,6 +131,7 @@ class ComfyExtension(ABC):
class Input:
Image = ImageInput
Audio = AudioInput
ImageStream = ImageStreamInput
Mask = MaskInput
Latent = LatentInput
Video = VideoInput

View File

@ -1,10 +1,12 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .image_stream_types import ImageStreamInput
from .video_types import VideoInput
__all__ = [
"ImageInput",
"AudioInput",
"ImageStreamInput",
"VideoInput",
"MaskInput",
"LatentInput",

View File

@ -0,0 +1,67 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import nullcontext
from comfy_execution.utils import CurrentNodeContext, get_executing_context
from comfy_execution.progress import get_progress_state
from .basic_types import ImageInput
class ImageStreamInput(ABC):
"""Abstract base class for pull-based image stream inputs.
Consumers request up to ``max_frames`` frames at a time. Producers must not
over-return; a batch with fewer than ``max_frames`` frames signals EOF.
"""
def __init__(self):
# Subclasses must call this init, for compatibility with future core ComfyUI changes.
self._ctx = get_executing_context()
def reset(self) -> None:
# This API is final. Subclasses must NOT override it, for compatibility with
# future core ComfyUI changes. Override do_reset instead.
with (nullcontext() if self._ctx is None else
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
self.do_reset()
if self._ctx is not None:
get_progress_state().finish_progress(self._ctx.node_id)
def pull(self, max_frames: int) -> ImageInput:
# This API is final. Subclasses must NOT override it, for compatibility with
# future core ComfyUI changes. Override do_pull instead.
with (nullcontext() if self._ctx is None else
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
result = self.do_pull(max_frames)
if self._ctx is not None:
registry = get_progress_state()
entry = registry.nodes.get(self._ctx.node_id)
if (int(result.shape[0]) < max_frames or
(entry is not None and entry["max"] > 0 and entry["value"] >= entry["max"])):
registry.finish_progress(self._ctx.node_id)
return result
@abstractmethod
def get_dimensions(self) -> tuple[int, int]:
"""Return the stream frame dimensions as ``(width, height)``."""
pass
@abstractmethod
def do_reset(self) -> None:
"""Reset the stream so the next pull starts from frame 0."""
pass
@abstractmethod
def do_pull(self, max_frames: int) -> ImageInput:
"""Return up to ``max_frames`` images.
The returned tensor uses the normal ``IMAGE`` batch shape. A short
return, where the batch dimension is less than ``max_frames``, is the
EOF signal. Sources are expected to short-return at least once before
exhaustion, including returning an empty batch.
"""
pass
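
A minimal sketch of the contract, assuming only the class above: a producer overrides the do_* hooks (the class and names below are illustrative, not from the commit), and a consumer calls the public reset/pull pair until it sees a short batch:

import torch

class TensorBackedStream(ImageStreamInput):
    def __init__(self, frames: torch.Tensor):
        super().__init__()  # required: captures the executing node context
        self._frames = frames
        self._index = 0

    def get_dimensions(self) -> tuple[int, int]:
        return self._frames.shape[2], self._frames.shape[1]  # (width, height)

    def do_reset(self) -> None:
        self._index = 0

    def do_pull(self, max_frames: int) -> torch.Tensor:
        chunk = self._frames[self._index:self._index + max_frames]
        self._index += chunk.shape[0]
        return chunk  # a short (possibly empty) batch signals EOF

stream = TensorBackedStream(torch.zeros(10, 64, 64, 3))
stream.reset()  # always call the public methods, never do_reset/do_pull directly
while True:
    batch = stream.pull(4)
    if batch.shape[0] < 4:
        break  # EOF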

View File

@ -386,6 +386,7 @@ class VideoFromComponents(VideoInput):
def __init__(self, components: VideoComponents):
self.__components = components
self._frame_counter = 0
def get_components(self) -> VideoComponents:
return VideoComponents(
@ -394,14 +395,13 @@ class VideoFromComponents(VideoInput):
frame_rate=self.__components.frame_rate,
)
def save_to(
def save_start(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -413,7 +413,12 @@ class VideoFromComponents(VideoInput):
# BytesIO has no file extension, so av.open can't infer the format.
# Default to mp4 since that's the only supported format anyway.
extra_kwargs["format"] = "mp4"
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
width, height = self.get_dimensions()
output = av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs)
if True:  # placeholder for the removed 'with' block; keeps the body's indentation unchanged
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
@ -422,8 +427,8 @@ class VideoFromComponents(VideoInput):
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.width = width
video_stream.height = height
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
@ -432,23 +437,33 @@ class VideoFromComponents(VideoInput):
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
waveform = waveform[0]
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
self._frame_counter = 0
return output, video_stream, audio_stream, audio_sample_rate, frame_rate
def save_add(self, output, video_stream, images) -> None:
# Encode video
for i, frame in enumerate(self.__components.images):
for frame in images:
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
self._frame_counter += 1
def save_finalize(self, output, video_stream, audio_stream, audio_sample_rate, frame_rate) -> None:
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
waveform = self.__components.audio['waveform']
waveform = waveform[0]
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
waveform = waveform[:, :math.ceil((audio_sample_rate / frame_rate) * self._frame_counter)]
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
@ -457,6 +472,29 @@ class VideoFromComponents(VideoInput):
# Flush encoder
output.mux(audio_stream.encode(None))
output.close()
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
output, video_stream, audio_stream, audio_sample_rate, frame_rate = self.save_start(
path,
format=format,
codec=codec,
metadata=metadata,
)
try:
self.save_add(output, video_stream, self.__components.images)
self.save_finalize(output, video_stream, audio_stream, audio_sample_rate, frame_rate)
except Exception:
output.close()
raise
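
Because saving is now split into start/add/finalize, the same primitives can be driven incrementally; a minimal sketch, assuming video is a VideoFromComponents and chunks is an iterable of BHWC image batches:

output, video_stream, audio_stream, audio_sample_rate, frame_rate = video.save_start(
    "out.mp4",
)
try:
    for chunk in chunks:
        # frames are encoded (and counted) as they arrive
        video.save_add(output, video_stream, chunk)
    # writes the audio track trimmed to the encoded frame count, then closes the container
    video.save_finalize(output, video_stream, audio_stream, audio_sample_rate, frame_rate)
except Exception:
    output.close()  # mirror save_to's cleanup on failure
    raise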
def as_trimmed(
self,
start_time: float | None = None,

View File

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from comfy.samplers import CFGGuider, Sampler
from comfy.sd import CLIP, VAE
from comfy.sd import StyleModel as StyleModel_
from comfy_api.input import VideoInput, CurveInput as CurveInput_
from comfy_api.input import ImageStreamInput, VideoInput, CurveInput as CurveInput_
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
@ -420,6 +420,12 @@ class Image(ComfyTypeIO):
Type = torch.Tensor
@comfytype(io_type="IMAGE_STREAM")
class ImageStream(ComfyTypeIO):
if TYPE_CHECKING:
Type = ImageStreamInput
@comfytype(io_type="WAN_CAMERA_EMBEDDING")
class WanCameraEmbedding(ComfyTypeIO):
Type = torch.Tensor
@ -2203,6 +2209,7 @@ __all__ = [
"Combo",
"MultiCombo",
"Image",
"ImageStream",
"WanCameraEmbedding",
"Webcam",
"Mask",

View File

@ -0,0 +1,320 @@
from __future__ import annotations
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, io, ui
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from server import PromptServer
class FrameProgressTracker:
def __init__(self):
self._last_reported = None
def emit(self, frames_processed):
if frames_processed == self._last_reported:
return
current = get_executing_context()
server = getattr(PromptServer, "instance", None)
if current is None or server is None or server.client_id is None:
return
server.send_progress_text(
f"processed {frames_processed} frames",
current.node_id,
server.client_id,
)
self._last_reported = frames_processed
def drain_image_stream(stream, chunk_size, progress):
stream.reset()
frames_processed = 0
while True:
progress.emit(frames_processed)
chunk = stream.pull(chunk_size)
frames_processed += int(chunk.shape[0])
if chunk.shape[0] < chunk_size:
progress.emit(frames_processed)
return frames_processed
class TensorImageStream(Input.ImageStream):
"""Simple IMAGE_STREAM backed by a materialized IMAGE batch tensor."""
def __init__(self, images: Input.Image):
super().__init__()
self._images = images
self._index = 0
self._total_frames = int(images.shape[0])
self._progress_started = False
def _update_progress(self, value: int) -> None:
current = get_executing_context()
if current is None:
return
get_progress_state().update_progress(
current.node_id,
value=float(value),
max_value=float(max(self._total_frames, 1)),
)
def get_dimensions(self) -> tuple[int, int]:
return self._images.shape[2], self._images.shape[1]
def do_reset(self) -> None:
self._index = 0
self._progress_started = False
def do_pull(self, max_frames: int) -> Input.Image:
if not self._progress_started:
self._update_progress(0)
self._progress_started = True
start = self._index
end = min(start + max_frames, self._images.shape[0])
self._index = end
chunk = self._images[start:end].clone()
self._update_progress(end)
return chunk
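
Taken together, these pieces compose directly; a minimal sketch, assuming the definitions above:

import torch

# Wrap a materialized batch as a stream, then drain it to EOF with progress reporting.
stream = TensorImageStream(torch.zeros(32, 64, 64, 3))  # 32 blank 64x64 frames
total = drain_image_stream(stream, chunk_size=8, progress=FrameProgressTracker())
assert total == 32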
class PreviewingImageStream(Input.ImageStream):
def __init__(self, stream: Input.ImageStream):
super().__init__()
self._stream = stream
def _emit_preview(self, chunk: Input.Image) -> None:
if int(chunk.shape[0]) == 0:
return
current = get_executing_context()
if current is None:
return
server = getattr(PromptServer, "instance", None)
if server is None or server.client_id is None:
return
preview_output = ui.PreviewImage(chunk[-1:]).as_dict()
server.send_sync(
"executed",
{
"node": current.node_id,
"display_node": current.node_id,
"output": preview_output,
"prompt_id": current.prompt_id,
},
server.client_id,
)
def get_dimensions(self) -> tuple[int, int]:
return self._stream.get_dimensions()
def do_reset(self) -> None:
self._stream.reset()
def do_pull(self, max_frames: int) -> Input.Image:
chunk = self._stream.pull(max_frames)
self._emit_preview(chunk)
return chunk
class VAEDecodedImageStream(Input.ImageStream):
def __init__(self, vae, latent: Input.Latent):
super().__init__()
self._vae = vae
self._latent = latent
vae.throw_exception_if_invalid()
if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
if latent.ndim != 5:
raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
if latent.shape[0] != 1:
raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
output_shape = vae.decode_output_shape(latent.shape)
self._channels = int(output_shape[1])
self._width = int(output_shape[4])
self._height = int(output_shape[3])
self._total_frames = int(output_shape[0] * output_shape[2])
self._frames_emitted = 0
def _update_progress(self, value: int) -> None:
current = get_executing_context()
if current is None:
return
get_progress_state().update_progress(
current.node_id,
value=float(value),
max_value=float(max(self._total_frames, 1)),
)
def get_dimensions(self) -> tuple[int, int]:
return self._width, self._height
def do_reset(self) -> None:
self._frames_emitted = 0
self._update_progress(0)
self._vae.decode_stream_start(self._latent)
def do_pull(self, max_frames: int) -> Input.Image:
chunk = self._vae.first_stage_model.decode_chunk(max_frames)
if chunk is None:
return torch.empty(
(0, self._height, self._width, self._channels),
device=self._vae.output_device,
dtype=self._vae.vae_output_dtype(),
)
chunk = chunk.to(device=self._vae.output_device, dtype=self._vae.vae_output_dtype())
chunk = self._vae.process_output(chunk).movedim(1, -1)
chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))
self._frames_emitted += int(chunk.shape[0])
self._update_progress(self._frames_emitted)
return chunk
class ImageBatchToStream(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageBatchToStream",
display_name="Image Batch To Stream",
category="image/stream",
search_aliases=["image to stream", "batch to stream", "frames to stream"],
description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.",
inputs=[
io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format."),
],
outputs=[
io.ImageStream.Output(display_name="stream"),
],
)
@classmethod
def execute(cls, image: Input.Image) -> io.NodeOutput:
return io.NodeOutput(TensorImageStream(image))
class ImageStreamToBatch(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageStreamToBatch",
display_name="Image Stream To Batch",
category="image/stream",
search_aliases=["stream to image", "stream to batch", "collect stream"],
description="Materializes an IMAGE_STREAM back into a batched IMAGE tensor.",
inputs=[
io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM."),
io.Int.Input("batch_size", default=4096, min=1, max=4096),
],
outputs=[
io.Image.Output(display_name="image"),
],
)
@classmethod
def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput:
chunks: list[Input.Image] = []
stream.reset()
while True:
chunk = stream.pull(batch_size)
chunks.append(chunk)
if chunk.shape[0] < batch_size:
break
return io.NodeOutput(torch.cat(chunks, dim=0))
class VAEDecodeStream(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAEDecodeStream",
display_name="VAE Decode Stream",
category="image/stream",
search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
description="Decodes a latent into an IMAGE_STREAM.",
inputs=[
io.Latent.Input("samples", tooltip="The LTX latent to decode."),
io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
],
outputs=[
io.ImageStream.Output(display_name="stream"),
],
)
@classmethod
def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
latent = samples["samples"]
if latent.is_nested:
latent = latent.unbind()[0]
return io.NodeOutput(VAEDecodedImageStream(vae, latent))
class PreviewImageStream(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PreviewImageStream",
display_name="Preview Image Stream",
category="image/stream",
search_aliases=["stream preview", "preview frames", "preview image stream"],
description="Passes an IMAGE_STREAM through while previewing the last frame from each pulled chunk.",
has_intermediate_output=True,
inputs=[
io.ImageStream.Input("stream", tooltip="The image stream to preview inline."),
],
outputs=[
io.ImageStream.Output(display_name="passthrough"),
],
)
@classmethod
def execute(cls, stream: Input.ImageStream) -> io.NodeOutput:
return io.NodeOutput(PreviewingImageStream(stream))
class StreamSink(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StreamSink",
search_aliases=["consume stream", "drain stream", "image stream sink"],
display_name="Stream Sink",
category="image/stream",
description="Consumes an IMAGE_STREAM by pulling it to EOF.",
inputs=[
io.ImageStream.Input("stream", tooltip="The image stream to consume."),
io.Int.Input("chunk_size", default=8, min=1, max=4096),
],
outputs=[],
is_output_node=True,
)
@classmethod
def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput:
drain_image_stream(stream, chunk_size, progress=FrameProgressTracker())
return io.NodeOutput()
class ImageStreamExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ImageBatchToStream,
ImageStreamToBatch,
VAEDecodeStream,
PreviewImageStream,
StreamSink,
]
async def comfy_entrypoint() -> ImageStreamExtension:
return ImageStreamExtension()

View File

@ -5,8 +5,7 @@ import torch
import comfy.utils
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
from comfy_api.latest import ComfyExtension, Input, io
try:
from spandrel_extra_arches import EXTRA_REGISTRY
@ -47,9 +46,29 @@ class UpscaleModelLoader(io.ComfyNode):
load_model = execute # TODO: remove
class UpscaledImageStream(Input.ImageStream):
def __init__(self, upscale_model, stream: Input.ImageStream):
super().__init__()
self._upscale_model = upscale_model
self._stream = stream
def get_dimensions(self) -> tuple[int, int]:
width, height = self._stream.get_dimensions()
scale = self._upscale_model.scale
return int(width * scale), int(height * scale)
def do_reset(self) -> None:
self._stream.reset()
def do_pull(self, max_frames: int) -> Input.Image:
chunk = self._stream.pull(max_frames)
return ImageUpscaleWithModel.upscale_batch(self._upscale_model, chunk)
class ImageUpscaleWithModel(io.ComfyNode):
@classmethod
def define_schema(cls):
image_template = io.MatchType.Template("image_type", allowed_types=[io.Image, io.ImageStream])
return io.Schema(
node_id="ImageUpscaleWithModel",
display_name="Upscale Image (using Model)",
@ -57,15 +76,18 @@ class ImageUpscaleWithModel(io.ComfyNode):
search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),
io.MatchType.Input("image", template=image_template),
],
outputs=[
io.Image.Output(),
io.MatchType.Output(template=image_template, display_name="image"),
],
)
@classmethod
def execute(cls, upscale_model, image) -> io.NodeOutput:
def upscale_batch(cls, upscale_model, image: torch.Tensor) -> torch.Tensor:
if image.shape[0] == 0:
return image.clone()
device = model_management.get_torch_device()
memory_required = model_management.module_size(upscale_model.model)
@ -79,7 +101,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
tile = 512
overlap = 32
output_device = comfy.model_management.intermediate_device()
output_device = model_management.intermediate_device()
oom = True
try:
@ -97,8 +119,14 @@ class ImageUpscaleWithModel(io.ComfyNode):
finally:
upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
return io.NodeOutput(s)
return torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(model_management.intermediate_dtype())
@classmethod
def execute(cls, upscale_model, image) -> io.NodeOutput:
if isinstance(image, torch.Tensor):
return io.NodeOutput(cls.upscale_batch(upscale_model, image))
return io.NodeOutput(UpscaledImageStream(upscale_model, image))
upscale = execute # TODO: remove

View File

@ -5,11 +5,162 @@ import av
import torch
import folder_paths
import json
from typing import Optional
from typing import Callable, Optional
from typing_extensions import override
from fractions import Fraction
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
from comfy.cli_args import args
from comfy_execution.utils import get_executing_context
from comfy_extras.nodes_image_stream import FrameProgressTracker, drain_image_stream
from server import PromptServer
class SavedVideoStream(Input.ImageStream):
def __init__(
self,
stream: Input.ImageStream,
saver,
output_factory: Callable[[], tuple[str, ui.PreviewVideo | None]],
format: str,
codec,
metadata: Optional[dict],
emit_preview_on_finalize: bool = True,
preview_node_id: Optional[str] = None,
preview_display_node_id: Optional[str] = None,
):
super().__init__()
self._stream = stream
self._saver = saver
self._output_factory = output_factory
self._path: str | None = None
self._format = Types.VideoContainer(format)
self._codec = Types.VideoCodec(codec)
self._metadata = metadata
self._preview_ui: ui.PreviewVideo | None = None
self._emit_preview_on_finalize = emit_preview_on_finalize
self._preview_node_id = preview_node_id
self._preview_display_node_id = preview_display_node_id
self._save_state = None
def _emit_preview(self) -> None:
if not self._emit_preview_on_finalize or self._preview_ui is None:
return
current = get_executing_context()
if current is None:
return
server = getattr(PromptServer, "instance", None)
if server is None or server.client_id is None:
return
server.send_sync(
"executed",
{
"node": self._preview_node_id or current.node_id,
"display_node": self._preview_display_node_id or self._preview_node_id or current.node_id,
"output": self._preview_ui.as_dict(),
"prompt_id": current.prompt_id,
},
server.client_id,
)
def _discard_partial_output(self) -> None:
if self._save_state is not None:
self._save_state[0].close()
self._save_state = None
if self._path is not None and os.path.exists(self._path):
os.remove(self._path)
self._path = None
self._preview_ui = None
def get_preview_ui(self) -> ui.PreviewVideo | None:
return self._preview_ui
def get_dimensions(self) -> tuple[int, int]:
return self._stream.get_dimensions()
def do_reset(self) -> None:
self._discard_partial_output()
self._stream.reset()
self._path, self._preview_ui = self._output_factory()
assert self._path is not None
open(self._path, "ab").close()
self._save_state = self._saver.save_start(
self._path,
format=self._format,
codec=self._codec,
metadata=self._metadata,
)
def do_pull(self, max_frames: int) -> Input.Image:
assert self._save_state is not None
chunk = self._stream.pull(max_frames)
self._saver.save_add(self._save_state[0], self._save_state[1], chunk)
if chunk.shape[0] < max_frames:
self._saver.save_finalize(*self._save_state)
self._save_state = None
self._emit_preview()
self._path = None
return chunk
def _build_saved_stream(
hidden,
stream: Input.ImageStream,
audio: Optional[Input.Audio],
fps: float,
filename_prefix,
format: str,
codec,
emit_preview: bool = True,
) -> SavedVideoStream:
width, height = stream.get_dimensions()
saved_metadata = None
if not args.disable_metadata:
metadata = {}
if hidden.extra_pnginfo is not None:
metadata.update(hidden.extra_pnginfo)
if hidden.prompt is not None:
metadata["prompt"] = hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
preview_node_id = hidden.unique_id
preview_display_node_id = preview_node_id
if hidden.dynprompt is not None and preview_node_id is not None:
preview_display_node_id = hidden.dynprompt.get_display_node_id(preview_node_id)
def output_factory() -> tuple[str, ui.PreviewVideo]:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix,
folder_paths.get_output_directory(),
width,
height,
)
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
return (
os.path.join(full_output_folder, file),
ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]),
)
return SavedVideoStream(
stream,
InputImpl.VideoFromComponents(
Types.VideoComponents(
images=torch.zeros((0, height, width, 3)),
audio=audio,
frame_rate=Fraction(fps),
)
),
output_factory,
format,
codec,
saved_metadata,
emit_preview,
preview_node_id,
preview_display_node_id,
)
class SaveWEBM(io.ComfyNode):
@classmethod
@ -114,6 +265,93 @@ class SaveVideo(io.ComfyNode):
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
class SavePassthroughVideoStream(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SavePassthroughVideoStream",
search_aliases=["stream to video", "save image stream", "export video stream", "passthrough video stream"],
display_name="Save+Passthrough Video Stream",
category="image/video",
essentials_category="Basics",
description="Saves frames as they pass through the input image stream.",
has_intermediate_output=True,
inputs=[
io.ImageStream.Input("stream", tooltip="The image stream to save."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
],
outputs=[
io.ImageStream.Output(display_name="passthrough"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
)
@classmethod
def execute(
cls,
stream: Input.ImageStream,
fps: float,
filename_prefix,
format: str,
codec,
audio: Optional[Input.Audio] = None,
) -> io.NodeOutput:
return io.NodeOutput(_build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec))
class SaveVideoStream(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveVideoStream",
search_aliases=["save image stream", "export video stream", "stream to video"],
display_name="Save Video Stream",
category="image/video",
essentials_category="Basics",
description="Saves an image stream by draining it directly to EOF.",
is_output_node=True,
inputs=[
io.ImageStream.Input("stream", tooltip="The image stream to save."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.Int.Input("chunk_size", default=8, min=1, max=4096),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
)
@classmethod
def execute(
cls,
stream: Input.ImageStream,
fps: float,
chunk_size: int,
filename_prefix,
format: str,
codec,
audio: Optional[Input.Audio] = None,
) -> io.NodeOutput:
saved_stream = _build_saved_stream(
cls.hidden,
stream,
audio,
fps,
filename_prefix,
format,
codec,
emit_preview=False,
)
drain_image_stream(saved_stream, chunk_size, progress=FrameProgressTracker())
return io.NodeOutput(ui=saved_stream.get_preview_ui())
class CreateVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -262,6 +500,8 @@ class VideoExtension(ComfyExtension):
return [
SaveWEBM,
SaveVideo,
SavePassthroughVideoStream,
SaveVideoStream,
CreateVideo,
GetVideoComponents,
LoadVideo,

View File

@ -2414,6 +2414,7 @@ async def init_builtin_extra_nodes():
"nodes_hooks.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_image_stream.py",
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",