diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 57126fa4a..18eba6dc5 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -51,6 +51,7 @@ class IO(StrEnum):
     BBOX = "BBOX"
     SEGS = "SEGS"
     VIDEO = "VIDEO"
+    IMAGE_STREAM = "IMAGE_STREAM"
     ANY = "*"
     """Always matches any type, but at a price.
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 998122c85..91606ffa6 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -16,6 +16,17 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed

 ops = comfy.ops.disable_weight_init

+class RunUpState:
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None):
+        self.timestep_shift_scale = timestep_shift_scale
+        self.scaled_timestep = scaled_timestep
+        self.checkpoint_fn = checkpoint_fn
+        self.max_chunk_size = max_chunk_size
+        self.output_shape = output_shape
+        self.output_dtype = output_dtype
+        self.output_frames = output_frames
+        self.pending_samples = []
+
 def in_meta_context():
     return torch.device("meta") == torch.empty(0).device

@@ -26,6 +37,14 @@ def mark_conv3d_ended(module):
         current = m.temporal_cache_state.get(tid, (None, False))
         m.temporal_cache_state[tid] = (current[0], True)

+def clear_temporal_cache_state(module):
+    # ComfyUI doesn't thread this kind of stuff today, but just in case
+    # we key on the thread to make it thread safe.
+    tid = threading.get_ident()
+    for _, m in module.named_modules():
+        if hasattr(m, "temporal_cache_state"):
+            m.temporal_cache_state.pop(tid, None)
+
 def split2(tensor, split_point, dim=2):
     return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)

@@ -315,13 +334,7 @@ class Encoder(nn.Module):
         try:
             return self.forward_orig(*args, **kwargs)
         finally:
-            tid = threading.get_ident()
-            for _, module in self.named_modules():
-                # ComfyUI doesn't thread this kind of stuff today, but just in case
-                # we key on the thread to make it thread safe.
-                tid = threading.get_ident()
-                if hasattr(module, "temporal_cache_state"):
-                    module.temporal_cache_state.pop(tid, None)
+            clear_temporal_cache_state(self)


 MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3

@@ -530,19 +543,20 @@ class Decoder(nn.Module):
             ).unsqueeze(1).expand(2, output_channel),
             persistent=False,
         )
+        self.temporal_cache_state = {}

     def decode_output_shape(self, input_shape):
         c, (ts, hs, ws), to = self._output_scale
         return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)

-    def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
+    def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset):
         sample = sample_ref[0]
         sample_ref[0] = None
         if idx >= len(self.up_blocks):
             sample = self.conv_norm_out(sample)
-            if timestep_shift_scale is not None:
-                shift, scale = timestep_shift_scale
+            if run_up_state.timestep_shift_scale is not None:
+                shift, scale = run_up_state.timestep_shift_scale
                 sample = sample * (1 + scale) + shift
             sample = self.conv_act(sample)
             if ended:
@@ -550,38 +564,49 @@ class Decoder(nn.Module):
             sample = self.conv_out(sample, causal=self.causal)
             if sample is not None and sample.shape[2] > 0:
                 sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
-                t = sample.shape[2]
-                output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
+                if output_buffer is None:
+                    run_up_state.output_frames = sample
+                    return
+                output_slice = output_buffer[:, :, output_offset[0]:output_offset[0] + sample.shape[2]]
+                t = output_slice.shape[2]
+                output_slice.copy_(sample[:, :, :t])
                 output_offset[0] += t
+                if t < sample.shape[2]:
+                    run_up_state.output_frames = sample[:, :, t:]
             return

         up_block = self.up_blocks[idx]
         if ended:
             mark_conv3d_ended(up_block)
         if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
-            sample = checkpoint_fn(up_block)(
-                sample, causal=self.causal, timestep=scaled_timestep
+            sample = run_up_state.checkpoint_fn(up_block)(
+                sample, causal=self.causal, timestep=run_up_state.scaled_timestep
             )
         else:
-            sample = checkpoint_fn(up_block)(sample, causal=self.causal)
+            sample = run_up_state.checkpoint_fn(up_block)(sample, causal=self.causal)

         if sample is None or sample.shape[2] == 0:
             return

         total_bytes = sample.numel() * sample.element_size()
-        num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
+        num_chunks = (total_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size
         if num_chunks == 1:
             # when we are not chunking, detach our x so the callee can free it as soon as they are done
             next_sample_ref = [sample]
             del sample
-            self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+            # Run this unconditionally: even when the output buffer is already full,
+            # a lower-level chunker or the output-frame stash will capture the result.
+            self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
             return
         else:
-            samples = torch.chunk(sample, chunks=num_chunks, dim=2)
+            samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))

-        for chunk_idx, sample1 in enumerate(samples):
-            self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+        while len(samples):
+            if output_buffer is None or output_offset[0] == output_buffer.shape[2]:
+                run_up_state.pending_samples.append((idx + 1, samples, ended))
+                return
+            self.run_up(idx + 1, [samples.pop(0)], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)  # pop evaluates first, so empty means last chunk

     def forward_orig(
         self,
@@ -591,6 +616,7 @@ class Decoder(nn.Module):
     ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
         batch_size = sample.shape[0]
+        output_shape = self.decode_output_shape(sample.shape)

         mark_conv3d_ended(self.conv_in)
         sample = self.conv_in(sample, causal=self.causal)
@@ -630,29 +656,89 @@ class Decoder(nn.Module):
             )
             timestep_shift_scale = ada_values.unbind(dim=1)

+        output_offset = [0]
+
+        run_up_state = RunUpState(
+            timestep_shift_scale=timestep_shift_scale,
+            scaled_timestep=scaled_timestep,
+            checkpoint_fn=checkpoint_fn,
+            max_chunk_size=get_max_chunk_size(sample.device),
+            output_shape=output_shape,
+            output_dtype=sample.dtype,
+        )
+        self.temporal_cache_state[threading.get_ident()] = run_up_state
+
+        self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset)
+
+        return output_buffer
+
+    def forward_start(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Optional[torch.Tensor] = None,
+    ):
+        try:
+            return self.forward_orig(sample, timestep=timestep, output_buffer=None)
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
+
+    def forward_resume(self, output_t: int):
+        tid = threading.get_ident()
+        run_up_state = self.temporal_cache_state.get(tid, None)
+        if run_up_state is None:
+            return None
+
+        output_shape = list(run_up_state.output_shape)
+        output_shape[2] = output_t
+        output_buffer = torch.empty(
+            output_shape,
+            dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(),
+        )
+        output_offset = [0]
+
+        try:
+            if run_up_state.output_frames is not None:
+                output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
+                t = output_slice.shape[2]
+                output_slice.copy_(run_up_state.output_frames[:, :, :t])
+                output_offset[0] += t
+                run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]
+
+            pending_samples = run_up_state.pending_samples
+            run_up_state.pending_samples = []
+            while len(pending_samples):
+                idx, samples, ended = pending_samples.pop(0)
+                while len(samples):
+                    if output_offset[0] == output_buffer.shape[2]:
+                        pending_samples = [(idx, samples, ended)] + pending_samples
+                        run_up_state.pending_samples.extend(pending_samples)
+                        return output_buffer
+                    sample1 = samples.pop(0)
+                    self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)

+            if run_up_state.output_frames is None and not run_up_state.pending_samples:
+                clear_temporal_cache_state(self)
+            return output_buffer[:, :, :output_offset[0]]
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Optional[torch.Tensor] = None,
+        output_buffer: Optional[torch.Tensor] = None,
+    ):
         if output_buffer is None:
             output_buffer = torch.empty(
                self.decode_output_shape(sample.shape),
                dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
            )
-        output_offset = [0]
-
-        max_chunk_size = get_max_chunk_size(sample.device)
-
-        self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
-
-        return output_buffer
-
-    def forward(self, *args, **kwargs):
         try:
-            return self.forward_orig(*args, **kwargs)
+            return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
         finally:
-            for _, module in self.named_modules():
-                #ComfyUI doesn't thread this kind of stuff today, but just incase
-                #we key on the thread to make it thread safe.
-                tid = threading.get_ident()
-                if hasattr(module, "temporal_cache_state"):
-                    module.temporal_cache_state.pop(tid, None)
+            clear_temporal_cache_state(self)


 class UNetMidBlock3D(nn.Module):
@@ -1302,6 +1388,15 @@ class VideoVAE(nn.Module):
     def decode_output_shape(self, input_shape):
         return self.decoder.decode_output_shape(input_shape)

+    def decode_start(self, x):
+        clear_temporal_cache_state(self.decoder)
+        if self.timestep_conditioning: #TODO: seed
+            x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
+        return self.decoder.forward_start(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
+
+    def decode_chunk(self, output_t: int):
+        return self.decoder.forward_resume(output_t)
+
     def decode(self, x, output_buffer=None):
         if self.timestep_conditioning: #TODO: seed
             x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
diff --git a/comfy/sd.py b/comfy/sd.py
index 736fe35de..28b24ff89 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1014,6 +1014,17 @@ class VAE:
             pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples

+    def decode_output_shape(self, samples_shape):
+        self.throw_exception_if_invalid()
+        if hasattr(self.first_stage_model, "decode_output_shape"):
+            return self.first_stage_model.decode_output_shape(samples_shape)
+        raise RuntimeError("This VAE does not expose decode output shape information.")
+
+    def decode_stream_start(self, samples_in):
+        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+        self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))
+
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py
index 16d4acfd1..8e2374aaf 100644
--- a/comfy_api/input/__init__.py
+++ b/comfy_api/input/__init__.py
@@ -2,6 +2,7 @@
 from comfy_api.latest._input import (
     ImageInput,
     AudioInput,
+    ImageStreamInput,
     MaskInput,
     LatentInput,
     VideoInput,
@@ -14,6 +15,7 @@
 __all__ = [
     "ImageInput",
     "AudioInput",
+    "ImageStreamInput",
     "MaskInput",
     "LatentInput",
     "VideoInput",
diff --git a/comfy_api/input/image_stream_types.py b/comfy_api/input/image_stream_types.py
new file mode 100644
index 000000000..b52d0c76d
--- /dev/null
+++ b/comfy_api/input/image_stream_types.py
@@ -0,0 +1,6 @@
+# This file only exists for backwards compatibility.
+from comfy_api.latest._input.image_stream_types import ImageStreamInput
+
+__all__ = [
+    "ImageStreamInput",
+]
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 04973fea0..c0493b3ca 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
 from comfy_api.internal import ComfyAPIBase
 from comfy_api.internal.singleton import ProxiedSingleton
 from comfy_api.internal.async_to_sync import create_sync_class
-from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
+from ._input import ImageInput, AudioInput, ImageStreamInput, MaskInput, LatentInput, VideoInput
 from ._input_impl import VideoFromFile, VideoFromComponents
 from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
 from . import _io_public as io
@@ -131,6 +131,7 @@ class ComfyExtension(ABC):
 class Input:
     Image = ImageInput
     Audio = AudioInput
+    ImageStream = ImageStreamInput
     Mask = MaskInput
     Latent = LatentInput
     Video = VideoInput
diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py
index 05cd3d40a..3ec611879 100644
--- a/comfy_api/latest/_input/__init__.py
+++ b/comfy_api/latest/_input/__init__.py
@@ -1,10 +1,12 @@
 from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
 from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
+from .image_stream_types import ImageStreamInput
 from .video_types import VideoInput

 __all__ = [
     "ImageInput",
     "AudioInput",
+    "ImageStreamInput",
     "VideoInput",
     "MaskInput",
     "LatentInput",
diff --git a/comfy_api/latest/_input/image_stream_types.py b/comfy_api/latest/_input/image_stream_types.py
new file mode 100644
index 000000000..a582859ad
--- /dev/null
+++ b/comfy_api/latest/_input/image_stream_types.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+
+from comfy_execution.utils import CurrentNodeContext, get_executing_context
+from comfy_execution.progress import get_progress_state
+from .basic_types import ImageInput
+
+
+class ImageStreamInput(ABC):
+    """Abstract base class for pull-based image stream inputs.
+
+    Consumers request up to ``max_frames`` frames at a time. Producers must not
+    over-return; a batch with fewer than ``max_frames`` frames signals EOF.
+    """
+
+    def __init__(self):
+        # Subclasses must call this init, for compatibility with future core ComfyUI changes
+        self._ctx = get_executing_context()
+
+    def reset(self) -> None:
+        # This API is final. Subclasses must NOT override it, for compatibility with
+        # future core ComfyUI changes. Override do_reset instead.
+        with (nullcontext() if self._ctx is None else
+              CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
+            self.do_reset()
+
+        if self._ctx is not None:
+            get_progress_state().finish_progress(self._ctx.node_id)
+
+    def pull(self, max_frames: int) -> ImageInput:
+        # This API is final. Subclasses must NOT override it, for compatibility with
+        # future core ComfyUI changes. Override do_pull instead.
+        with (nullcontext() if self._ctx is None else
+              CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
+            result = self.do_pull(max_frames)
+
+        if self._ctx is not None:
+            registry = get_progress_state()
+            entry = registry.nodes.get(self._ctx.node_id)
+            if (int(result.shape[0]) < max_frames or
+                    (entry is not None and entry["max"] > 0 and entry["value"] >= entry["max"])):
+                registry.finish_progress(self._ctx.node_id)
+
+        return result
+
+    @abstractmethod
+    def get_dimensions(self) -> tuple[int, int]:
+        """Return the stream frame dimensions as ``(width, height)``."""
+        pass
+
+    @abstractmethod
+    def do_reset(self) -> None:
+        """Reset the stream so the next pull starts from frame 0."""
+        pass
+
+    @abstractmethod
+    def do_pull(self, max_frames: int) -> ImageInput:
+        """Return up to ``max_frames`` images.
+
+        The returned tensor uses the normal ``IMAGE`` batch shape. A short
+        return, where the batch dimension is less than ``max_frames``, is the
+        EOF signal. Sources must short-return at least once when exhausted; if
+        the final frames exactly filled a pull, the next pull returns an empty batch.
+        """
+        pass
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 1b4993aa7..4cde7763a 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -386,6 +386,7 @@ class VideoFromComponents(VideoInput):

     def __init__(self, components: VideoComponents):
         self.__components = components
+        self._frame_counter = 0

     def get_components(self) -> VideoComponents:
         return VideoComponents(
@@ -394,14 +395,13 @@ class VideoFromComponents(VideoInput):
             frame_rate=self.__components.frame_rate,
         )

-    def save_to(
+    def save_start(
         self,
         path: str,
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
         metadata: Optional[dict] = None,
     ):
-        """Save the video to a file path or BytesIO buffer."""
         if format != VideoContainer.AUTO and format != VideoContainer.MP4:
             raise ValueError("Only MP4 format is supported for now")
         if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@@ -413,7 +413,12 @@ class VideoFromComponents(VideoInput):
             # BytesIO has no file extension, so av.open can't infer the format.
             # Default to mp4 since that's the only supported format anyway.
extra_kwargs["format"] = "mp4" - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: + + width, height = self.get_dimensions() + + output = av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) + if True: + # Add metadata before writing any streams if metadata is not None: for key, value in metadata.items(): @@ -422,8 +427,8 @@ class VideoFromComponents(VideoInput): frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) # Create a video stream video_stream = output.add_stream('h264', rate=frame_rate) - video_stream.width = self.__components.images.shape[2] - video_stream.height = self.__components.images.shape[1] + video_stream.width = width + video_stream.height = height video_stream.pix_fmt = 'yuv420p' # Create an audio stream @@ -432,23 +437,33 @@ class VideoFromComponents(VideoInput): if self.__components.audio: audio_sample_rate = int(self.__components.audio['sample_rate']) waveform = self.__components.audio['waveform'] - waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] + waveform = waveform[0] layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo') audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout) + self._frame_counter = 0 + return output, video_stream, audio_stream, audio_sample_rate, frame_rate + + def save_add(self, output, video_stream, images) -> None: # Encode video - for i, frame in enumerate(self.__components.images): + for frame in images: img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) frame = av.VideoFrame.from_ndarray(img, format='rgb24') frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 packet = video_stream.encode(frame) output.mux(packet) + self._frame_counter += 1 + def save_finalize(self, output, video_stream, audio_stream, audio_sample_rate, frame_rate) -> None: # Flush video packet = video_stream.encode(None) output.mux(packet) if audio_stream and self.__components.audio: + waveform = self.__components.audio['waveform'] + waveform = waveform[0] + layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo') + waveform = waveform[:, :math.ceil((audio_sample_rate / frame_rate) * self._frame_counter)] frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout) frame.sample_rate = audio_sample_rate frame.pts = 0 @@ -457,6 +472,29 @@ class VideoFromComponents(VideoInput): # Flush encoder output.mux(audio_stream.encode(None)) + output.close() + + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None, + ): + """Save the video to a file path or BytesIO buffer.""" + output, video_stream, audio_stream, audio_sample_rate, frame_rate = self.save_start( + path, + format=format, + codec=codec, + metadata=metadata, + ) + try: + self.save_add(output, video_stream, self.__components.images) + self.save_finalize(output, video_stream, audio_stream, audio_sample_rate, frame_rate) + except Exception: + output.close() + raise + def as_trimmed( self, start_time: float | None = None, diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index fdeffea2d..3fbd9e40e 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from comfy.samplers import CFGGuider, Sampler from comfy.sd import CLIP, VAE from comfy.sd import 
-    from comfy_api.input import VideoInput, CurveInput as CurveInput_
+    from comfy_api.input import ImageStreamInput, VideoInput, CurveInput as CurveInput_
 from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
                                 prune_dict, shallow_clone_class)
 from comfy_execution.graph_utils import ExecutionBlocker
@@ -420,6 +420,12 @@
 class Image(ComfyTypeIO):
     Type = torch.Tensor

+@comfytype(io_type="IMAGE_STREAM")
+class ImageStream(ComfyTypeIO):
+    if TYPE_CHECKING:
+        Type = ImageStreamInput
+
+
 @comfytype(io_type="WAN_CAMERA_EMBEDDING")
 class WanCameraEmbedding(ComfyTypeIO):
     Type = torch.Tensor
@@ -2203,6 +2209,7 @@ __all__ = [
     "Combo",
     "MultiCombo",
     "Image",
+    "ImageStream",
     "WanCameraEmbedding",
     "Webcam",
     "Mask",
diff --git a/comfy_extras/nodes_image_stream.py b/comfy_extras/nodes_image_stream.py
new file mode 100644
index 000000000..f01075887
--- /dev/null
+++ b/comfy_extras/nodes_image_stream.py
@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import torch
+from typing_extensions import override
+
+from comfy_api.latest import ComfyExtension, Input, io, ui
+from comfy_execution.progress import get_progress_state
+from comfy_execution.utils import get_executing_context
+from server import PromptServer
+
+
+class FrameProgressTracker:
+    def __init__(self):
+        self._last_reported = None
+
+    def emit(self, frames_processed):
+        if frames_processed == self._last_reported:
+            return
+
+        current = get_executing_context()
+        server = getattr(PromptServer, "instance", None)
+        if current is None or server is None or server.client_id is None:
+            return
+
+        server.send_progress_text(
+            f"processed {frames_processed} frames",
+            current.node_id,
+            server.client_id,
+        )
+        self._last_reported = frames_processed
+
+
+def drain_image_stream(stream, chunk_size, progress):
+    stream.reset()
+    frames_processed = 0
+
+    while True:
+        progress.emit(frames_processed)
+        chunk = stream.pull(chunk_size)
+        frames_processed += int(chunk.shape[0])
+        if chunk.shape[0] < chunk_size:
+            progress.emit(frames_processed)
+            return frames_processed
+
+class TensorImageStream(Input.ImageStream):
+    """Simple IMAGE_STREAM backed by a materialized IMAGE batch tensor."""
+
+    def __init__(self, images: Input.Image):
+        super().__init__()
+        self._images = images
+        self._index = 0
+        self._total_frames = int(images.shape[0])
+        self._progress_started = False
+
+    def _update_progress(self, value: int) -> None:
+        current = get_executing_context()
+        if current is None:
+            return
+        get_progress_state().update_progress(
+            current.node_id,
+            value=float(value),
+            max_value=float(max(self._total_frames, 1)),
+        )
+
+    def get_dimensions(self) -> tuple[int, int]:
+        return self._images.shape[2], self._images.shape[1]
+
+    def do_reset(self) -> None:
+        self._index = 0
+        self._progress_started = False
+
+    def do_pull(self, max_frames: int) -> Input.Image:
+        if not self._progress_started:
+            self._update_progress(0)
+            self._progress_started = True
+
+        start = self._index
+        end = min(start + max_frames, self._images.shape[0])
+        self._index = end
+        chunk = self._images[start:end].clone()
+        self._update_progress(end)
+        return chunk
+
+
+class PreviewingImageStream(Input.ImageStream):
+    def __init__(self, stream: Input.ImageStream):
+        super().__init__()
+        self._stream = stream
+
+    def _emit_preview(self, chunk: Input.Image) -> None:
+        if int(chunk.shape[0]) == 0:
+            return
+
+        current = get_executing_context()
+        if current is None:
+            return
+
+        server = getattr(PromptServer, "instance", None)
+        if server is None or server.client_id is None:
+            return
+
+        preview_output = ui.PreviewImage(chunk[-1:]).as_dict()
+        server.send_sync(
+            "executed",
+            {
+                "node": current.node_id,
+                "display_node": current.node_id,
+                "output": preview_output,
+                "prompt_id": current.prompt_id,
+            },
+            server.client_id,
+        )
+
+    def get_dimensions(self) -> tuple[int, int]:
+        return self._stream.get_dimensions()
+
+    def do_reset(self) -> None:
+        self._stream.reset()
+
+    def do_pull(self, max_frames: int) -> Input.Image:
+        chunk = self._stream.pull(max_frames)
+        self._emit_preview(chunk)
+        return chunk
+
+
+class VAEDecodedImageStream(Input.ImageStream):
+    def __init__(self, vae, latent: Input.Latent):
+        super().__init__()
+        self._vae = vae
+        self._latent = latent
+        vae.throw_exception_if_invalid()
+        if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
+            raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
+        if latent.ndim != 5:
+            raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
+        if latent.shape[0] != 1:
+            raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
+        output_shape = vae.decode_output_shape(latent.shape)
+        self._channels = int(output_shape[1])
+        self._width = int(output_shape[4])
+        self._height = int(output_shape[3])
+        self._total_frames = int(output_shape[0] * output_shape[2])
+        self._frames_emitted = 0
+
+    def _update_progress(self, value: int) -> None:
+        current = get_executing_context()
+        if current is None:
+            return
+        get_progress_state().update_progress(
+            current.node_id,
+            value=float(value),
+            max_value=float(max(self._total_frames, 1)),
+        )
+
+    def get_dimensions(self) -> tuple[int, int]:
+        return self._width, self._height
+
+    def do_reset(self) -> None:
+        self._frames_emitted = 0
+        self._update_progress(0)
+        self._vae.decode_stream_start(self._latent)
+
+    def do_pull(self, max_frames: int) -> Input.Image:
+        chunk = self._vae.first_stage_model.decode_chunk(max_frames)
+        if chunk is None:
+            return torch.empty(
+                (0, self._height, self._width, self._channels),
+                device=self._vae.output_device,
+                dtype=self._vae.vae_output_dtype(),
+            )
+
+        chunk = chunk.to(device=self._vae.output_device, dtype=self._vae.vae_output_dtype())
+        chunk = self._vae.process_output(chunk).movedim(1, -1)
+        chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))
+        self._frames_emitted += int(chunk.shape[0])
+        self._update_progress(self._frames_emitted)
+        return chunk
+
+
+class ImageBatchToStream(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ImageBatchToStream",
+            display_name="Image Batch To Stream",
+            category="image/stream",
+            search_aliases=["image to stream", "batch to stream", "frames to stream"],
+            description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.",
+            inputs=[
+                io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format."),
+            ],
+            outputs=[
+                io.ImageStream.Output(display_name="stream"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image: Input.Image) -> io.NodeOutput:
+        return io.NodeOutput(TensorImageStream(image))
+
+
+class ImageStreamToBatch(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ImageStreamToBatch",
+            display_name="Image Stream To Batch",
+            category="image/stream",
+            search_aliases=["stream to image", "stream to batch", "collect stream"],
+            description="Materializes an IMAGE_STREAM back into a batched IMAGE tensor.",
+            inputs=[
+                io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM."),
+                io.Int.Input("batch_size", default=4096, min=1, max=4096),
+            ],
+            outputs=[
+                io.Image.Output(display_name="image"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput:
+        chunks: list[Input.Image] = []
+        stream.reset()
+
+        while True:
+            chunk = stream.pull(batch_size)
+
+            chunks.append(chunk)
+            if chunk.shape[0] < batch_size:
+                break
+
+        return io.NodeOutput(torch.cat(chunks, dim=0))
+
+
+class VAEDecodeStream(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="VAEDecodeStream",
+            display_name="VAE Decode Stream",
+            category="image/stream",
+            search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
+            description="Decodes a latent into an IMAGE_STREAM.",
+            inputs=[
+                io.Latent.Input("samples", tooltip="The LTX latent to decode."),
+                io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
+            ],
+            outputs=[
+                io.ImageStream.Output(display_name="stream"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
+        latent = samples["samples"]
+        if latent.is_nested:
+            latent = latent.unbind()[0]
+        return io.NodeOutput(VAEDecodedImageStream(vae, latent))
+
+
+class PreviewImageStream(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="PreviewImageStream",
+            display_name="Preview Image Stream",
+            category="image/stream",
+            search_aliases=["stream preview", "preview frames", "preview image stream"],
+            description="Passes an IMAGE_STREAM through while previewing the last frame from each pulled chunk.",
+            has_intermediate_output=True,
+            inputs=[
+                io.ImageStream.Input("stream", tooltip="The image stream to preview inline."),
+            ],
+            outputs=[
+                io.ImageStream.Output(display_name="passthrough"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, stream: Input.ImageStream) -> io.NodeOutput:
+        return io.NodeOutput(PreviewingImageStream(stream))
+
+
+class StreamSink(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="StreamSink",
+            search_aliases=["consume stream", "drain stream", "image stream sink"],
+            display_name="Stream Sink",
+            category="image/stream",
+            description="Consumes an IMAGE_STREAM by pulling it to EOF.",
+            inputs=[
+                io.ImageStream.Input("stream", tooltip="The image stream to consume."),
+                io.Int.Input("chunk_size", default=8, min=1, max=4096),
+            ],
+            outputs=[],
+            is_output_node=True,
+        )
+
+    @classmethod
+    def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput:
+        drain_image_stream(stream, chunk_size, progress=FrameProgressTracker())
+        return io.NodeOutput()
+
+
+class ImageStreamExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            ImageBatchToStream,
+            ImageStreamToBatch,
+            VAEDecodeStream,
+            PreviewImageStream,
+            StreamSink,
+        ]
+
+
+async def comfy_entrypoint() -> ImageStreamExtension:
+    return ImageStreamExtension()
diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index d3ee3f1c1..e2ae24422 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -5,8 +5,7 @@
 import torch
 import comfy.utils
 import folder_paths
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, io
-import comfy.model_management
+from comfy_api.latest import ComfyExtension, Input, io

 try:
     from spandrel_extra_arches import EXTRA_REGISTRY
@@ -47,9 +46,29 @@ class UpscaleModelLoader(io.ComfyNode):
     load_model = execute  # TODO: remove


+class UpscaledImageStream(Input.ImageStream):
+    def __init__(self, upscale_model, stream: Input.ImageStream):
+        super().__init__()
+        self._upscale_model = upscale_model
+        self._stream = stream
+
+    def get_dimensions(self) -> tuple[int, int]:
+        width, height = self._stream.get_dimensions()
+        scale = self._upscale_model.scale
+        return int(width * scale), int(height * scale)
+
+    def do_reset(self) -> None:
+        self._stream.reset()
+
+    def do_pull(self, max_frames: int) -> Input.Image:
+        chunk = self._stream.pull(max_frames)
+        return ImageUpscaleWithModel.upscale_batch(self._upscale_model, chunk)
+
+
 class ImageUpscaleWithModel(io.ComfyNode):
     @classmethod
     def define_schema(cls):
+        image_template = io.MatchType.Template("image_type", allowed_types=[io.Image, io.ImageStream])
         return io.Schema(
             node_id="ImageUpscaleWithModel",
             display_name="Upscale Image (using Model)",
@@ -57,15 +76,18 @@ class ImageUpscaleWithModel(io.ComfyNode):
             search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
             inputs=[
                 io.UpscaleModel.Input("upscale_model"),
-                io.Image.Input("image"),
+                io.MatchType.Input("image", template=image_template),
             ],
             outputs=[
-                io.Image.Output(),
+                io.MatchType.Output(template=image_template, display_name="image"),
             ],
         )

     @classmethod
-    def execute(cls, upscale_model, image) -> io.NodeOutput:
+    def upscale_batch(cls, upscale_model, image: torch.Tensor) -> torch.Tensor:
+        if image.shape[0] == 0:
+            return image.clone()
+
         device = model_management.get_torch_device()
         memory_required = model_management.module_size(upscale_model.model)
@@ -79,7 +101,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
         tile = 512
         overlap = 32

-        output_device = comfy.model_management.intermediate_device()
+        output_device = model_management.intermediate_device()

         oom = True
         try:
@@ -97,8 +119,14 @@ class ImageUpscaleWithModel(io.ComfyNode):
         finally:
             upscale_model.to("cpu")

-        s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
-        return io.NodeOutput(s)
+        return torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(model_management.intermediate_dtype())
+
+    @classmethod
+    def execute(cls, upscale_model, image) -> io.NodeOutput:
+        if isinstance(image, torch.Tensor):
+            return io.NodeOutput(cls.upscale_batch(upscale_model, image))
+
+        return io.NodeOutput(UpscaledImageStream(upscale_model, image))

     upscale = execute  # TODO: remove
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 5c096c232..92fce02cf 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -5,11 +5,162 @@
 import av
 import torch
 import folder_paths
 import json
-from typing import Optional
+from typing import Callable, Optional
 from typing_extensions import override
 from fractions import Fraction
 from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
 from comfy.cli_args import args
+from comfy_execution.utils import get_executing_context
+from comfy_extras.nodes_image_stream import FrameProgressTracker, drain_image_stream
+from server import PromptServer
+
+
+class SavedVideoStream(Input.ImageStream):
+    def __init__(
+        self,
+        stream: Input.ImageStream,
+        saver,
+        output_factory: Callable[[], tuple[str, ui.PreviewVideo | None]],
+        format: str,
+        codec,
+        metadata: Optional[dict],
+        emit_preview_on_finalize: bool = True,
+        preview_node_id: Optional[str] = None,
+        preview_display_node_id: Optional[str] = None,
+    ):
+        super().__init__()
+        self._stream = stream
+        self._saver = saver
+        self._output_factory = output_factory
+        self._path: str | None = None
+        self._format = Types.VideoContainer(format)
+        self._codec = Types.VideoCodec(codec)
+        self._metadata = metadata
+        self._preview_ui: ui.PreviewVideo | None = None
+        self._emit_preview_on_finalize = emit_preview_on_finalize
+        self._preview_node_id = preview_node_id
+        self._preview_display_node_id = preview_display_node_id
+        self._save_state = None
+
+    def _emit_preview(self) -> None:
+        if not self._emit_preview_on_finalize or self._preview_ui is None:
+            return
+
+        current = get_executing_context()
+        if current is None:
+            return
+
+        server = getattr(PromptServer, "instance", None)
+        if server is None or server.client_id is None:
+            return
+
+        server.send_sync(
+            "executed",
+            {
+                "node": self._preview_node_id or current.node_id,
+                "display_node": self._preview_display_node_id or self._preview_node_id or current.node_id,
+                "output": self._preview_ui.as_dict(),
+                "prompt_id": current.prompt_id,
+            },
+            server.client_id,
+        )
+
+    def _discard_partial_output(self) -> None:
+        if self._save_state is not None:
+            self._save_state[0].close()
+            self._save_state = None
+        if self._path is not None and os.path.exists(self._path):
+            os.remove(self._path)
+        self._path = None
+        self._preview_ui = None
+
+    def get_preview_ui(self) -> ui.PreviewVideo | None:
+        return self._preview_ui
+
+    def get_dimensions(self) -> tuple[int, int]:
+        return self._stream.get_dimensions()
+
+    def do_reset(self) -> None:
+        self._discard_partial_output()
+        self._stream.reset()
+        self._path, self._preview_ui = self._output_factory()
+        assert self._path is not None
+        open(self._path, "ab").close()  # touch the file so the counter picked by output_factory stays reserved
+        self._save_state = self._saver.save_start(
+            self._path,
+            format=self._format,
+            codec=self._codec,
+            metadata=self._metadata,
+        )
+
+    def do_pull(self, max_frames: int) -> Input.Image:
+        assert self._save_state is not None
+        chunk = self._stream.pull(max_frames)
+        self._saver.save_add(self._save_state[0], self._save_state[1], chunk)
+        if chunk.shape[0] < max_frames:
+            self._saver.save_finalize(*self._save_state)
+            self._save_state = None
+            self._emit_preview()
+            self._path = None
+        return chunk
+
+
+def _build_saved_stream(
+    hidden,
+    stream: Input.ImageStream,
+    audio: Optional[Input.Audio],
+    fps: float,
+    filename_prefix,
+    format: str,
+    codec,
+    emit_preview: bool = True,
+) -> SavedVideoStream:
+    width, height = stream.get_dimensions()
+    saved_metadata = None
+    if not args.disable_metadata:
+        metadata = {}
+        if hidden.extra_pnginfo is not None:
+            metadata.update(hidden.extra_pnginfo)
+        if hidden.prompt is not None:
+            metadata["prompt"] = hidden.prompt
+        if len(metadata) > 0:
+            saved_metadata = metadata
+
+    preview_node_id = hidden.unique_id
+    preview_display_node_id = preview_node_id
+    if hidden.dynprompt is not None and preview_node_id is not None:
+        preview_display_node_id = hidden.dynprompt.get_display_node_id(preview_node_id)
+
+    def output_factory() -> tuple[str, ui.PreviewVideo]:
+        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
+            filename_prefix,
+            folder_paths.get_output_directory(),
+            width,
+            height,
+        )
+        file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
+        return (
+            os.path.join(full_output_folder, file),
+            ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]),
+        )
+
+    return SavedVideoStream(
+        stream,
+        InputImpl.VideoFromComponents(
+            Types.VideoComponents(
+                images=torch.zeros((0, height, width, 3)),
+                audio=audio,
+                frame_rate=Fraction(fps),
+            )
+        ),
+        output_factory,
+        format,
+        codec,
+        saved_metadata,
+        emit_preview,
+        preview_node_id,
+        preview_display_node_id,
+    )


 class SaveWEBM(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -114,6 +265,93 @@ class SaveVideo(io.ComfyNode):
         return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))


+class SavePassthroughVideoStream(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SavePassthroughVideoStream",
+            search_aliases=["stream to video", "save image stream", "export video stream", "passthrough video stream"],
+            display_name="Save+Passthrough Video Stream",
+            category="image/video",
+            essentials_category="Basics",
+            description="Saves frames as they pass through the input image stream.",
+            has_intermediate_output=True,
+            inputs=[
+                io.ImageStream.Input("stream", tooltip="The image stream to save."),
+                io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
+                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
+                io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
+                io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
+                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
+            ],
+            outputs=[
+                io.ImageStream.Output(display_name="passthrough"),
+            ],
+            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
+        )
+
+    @classmethod
+    def execute(
+        cls,
+        stream: Input.ImageStream,
+        fps: float,
+        filename_prefix,
+        format: str,
+        codec,
+        audio: Optional[Input.Audio] = None,
+    ) -> io.NodeOutput:
+        return io.NodeOutput(_build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec))
+
+
+class SaveVideoStream(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SaveVideoStream",
+            search_aliases=["save image stream", "export video stream", "stream to video"],
+            display_name="Save Video Stream",
+            category="image/video",
+            essentials_category="Basics",
+            description="Saves an image stream by draining it directly to EOF.",
+            is_output_node=True,
+            inputs=[
+                io.ImageStream.Input("stream", tooltip="The image stream to save."),
+                io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
+                io.Int.Input("chunk_size", default=8, min=1, max=4096),
+                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
+                io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
+                io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
+                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
+            ],
+            outputs=[],
+            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
+        )
+
+    @classmethod
+    def execute(
+        cls,
+        stream: Input.ImageStream,
+        fps: float,
+        chunk_size: int,
+        filename_prefix,
+        format: str,
+        codec,
+        audio: Optional[Input.Audio] = None,
+    ) -> io.NodeOutput:
+        saved_stream = _build_saved_stream(
+            cls.hidden,
+            stream,
+            audio,
+            fps,
+            filename_prefix,
+            format,
+            codec,
+            emit_preview=False,
+        )
+        drain_image_stream(saved_stream, chunk_size, progress=FrameProgressTracker())
+        return io.NodeOutput(ui=saved_stream.get_preview_ui())
+
+
 class CreateVideo(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -262,6 +500,8 @@ class VideoExtension(ComfyExtension):
         return [
             SaveWEBM,
             SaveVideo,
+            SavePassthroughVideoStream,
+            SaveVideoStream,
             CreateVideo,
             GetVideoComponents,
             LoadVideo,
diff --git a/nodes.py b/nodes.py
index fb83da896..64ad78528 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2414,6 +2414,7 @@ async def init_builtin_extra_nodes():
         "nodes_hooks.py",
         "nodes_load_3d.py",
         "nodes_cosmos.py",
+        "nodes_image_stream.py",
         "nodes_video.py",
         "nodes_lumina2.py",
         "nodes_wan.py",
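
Usage sketch (appendix, not part of the patch): a minimal IMAGE_STREAM source written against the ImageStreamInput ABC added above. SolidColorStream and its parameters are hypothetical illustration; only ImageStreamInput (via Input.ImageStream), get_dimensions, do_reset, do_pull, reset, and pull come from this diff.

import torch
from comfy_api.latest import Input

class SolidColorStream(Input.ImageStream):
    """Emits `total` gray frames, then signals EOF with a short or empty batch."""

    def __init__(self, width=64, height=64, total=10):
        super().__init__()  # required: captures the executing-node context
        self._width, self._height, self._total = width, height, total
        self._index = 0

    def get_dimensions(self):
        return self._width, self._height

    def do_reset(self):
        self._index = 0

    def do_pull(self, max_frames):
        # Never over-return: a batch smaller than max_frames is the EOF signal,
        # so a stream whose length divides evenly still ends with an empty batch.
        count = min(max_frames, self._total - self._index)
        self._index += count
        return torch.full((count, self._height, self._width, 3), 0.5)

# Consumers drive the stream through reset()/pull(), never the do_* hooks:
stream = SolidColorStream()
stream.reset()
while True:
    chunk = stream.pull(4)
    if chunk.shape[0] < 4:  # short return => exhausted
        break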
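
Driver sketch for the chunked decode pair (forward_start/forward_resume, surfaced as decode_start/decode_chunk): VAEDecodedImageStream in comfy_extras/nodes_image_stream.py is the real consumer; this hypothetical helper only isolates the control flow, assuming an LTX VAE wrapped in comfy.sd.VAE as in this diff.

def decode_in_chunks(vae, latent, frames_per_chunk=8):
    # forward_start: run the decoder with no output buffer, stashing frames/chunks
    vae.decode_stream_start(latent)
    while True:
        # forward_resume drains stashed frames and pending chunks into a fresh buffer
        chunk = vae.first_stage_model.decode_chunk(frames_per_chunk)
        if chunk is None:  # run-up state already cleared: nothing left to decode
            return
        chunk = vae.process_output(chunk.to(vae.output_device)).movedim(1, -1)
        yield chunk.reshape((-1,) + tuple(chunk.shape[-3:]))  # BCTHW -> (B*T)HWC IMAGE batch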
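
One more sketch: driving the three-phase save API split out of VideoFromComponents.save_to. The wiring below is hypothetical; note the zero-frame placeholder tensor only supplies get_dimensions(), exactly as _build_saved_stream does, so its width/height must match the frames actually added.

from fractions import Fraction
import torch
from comfy_api.latest import InputImpl, Types

def save_incrementally(path, frame_batches, width, height, fps=30.0, audio=None):
    saver = InputImpl.VideoFromComponents(Types.VideoComponents(
        images=torch.zeros((0, height, width, 3)),  # frames arrive via save_add
        audio=audio,
        frame_rate=Fraction(fps),
    ))
    output, video_stream, audio_stream, audio_sample_rate, frame_rate = saver.save_start(path)
    try:
        for batch in frame_batches:  # any iterable of BHWC image batches
            saver.save_add(output, video_stream, batch)
        # Flushes encoders, writes audio trimmed to the frames actually added, closes the file.
        saver.save_finalize(output, video_stream, audio_stream, audio_sample_rate, frame_rate)
    except Exception:
        output.close()
        raise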