mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-02 13:22:32 +08:00
Merge 1c2d37944c into 5edbdf4364
This commit is contained in:
commit
2e0438bbd9
@ -51,6 +51,7 @@ class IO(StrEnum):
|
||||
BBOX = "BBOX"
|
||||
SEGS = "SEGS"
|
||||
VIDEO = "VIDEO"
|
||||
IMAGE_STREAM = "IMAGE_STREAM"
|
||||
|
||||
ANY = "*"
|
||||
"""Always matches any type, but at a price.
|
||||
|
||||
@ -16,6 +16,17 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class RunUpState:
|
||||
def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None):
|
||||
self.timestep_shift_scale = timestep_shift_scale
|
||||
self.scaled_timestep = scaled_timestep
|
||||
self.checkpoint_fn = checkpoint_fn
|
||||
self.max_chunk_size = max_chunk_size
|
||||
self.output_shape = output_shape
|
||||
self.output_dtype = output_dtype
|
||||
self.output_frames = output_frames
|
||||
self.pending_samples = []
|
||||
|
||||
def in_meta_context():
|
||||
return torch.device("meta") == torch.empty(0).device
|
||||
|
||||
@ -26,6 +37,14 @@ def mark_conv3d_ended(module):
|
||||
current = m.temporal_cache_state.get(tid, (None, False))
|
||||
m.temporal_cache_state[tid] = (current[0], True)
|
||||
|
||||
def clear_temporal_cache_state(module):
|
||||
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
||||
# we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
if hasattr(m, "temporal_cache_state"):
|
||||
m.temporal_cache_state.pop(tid, None)
|
||||
|
||||
def split2(tensor, split_point, dim=2):
|
||||
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
|
||||
|
||||
@ -315,13 +334,7 @@ class Encoder(nn.Module):
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
tid = threading.get_ident()
|
||||
for _, module in self.named_modules():
|
||||
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
||||
# we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
clear_temporal_cache_state(self)
|
||||
|
||||
|
||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||
@ -530,19 +543,20 @@ class Decoder(nn.Module):
|
||||
).unsqueeze(1).expand(2, output_channel),
|
||||
persistent=False,
|
||||
)
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
|
||||
def decode_output_shape(self, input_shape):
|
||||
c, (ts, hs, ws), to = self._output_scale
|
||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||
|
||||
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
|
||||
def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset):
|
||||
sample = sample_ref[0]
|
||||
sample_ref[0] = None
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
if run_up_state.timestep_shift_scale is not None:
|
||||
shift, scale = run_up_state.timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
@ -550,38 +564,49 @@ class Decoder(nn.Module):
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
t = sample.shape[2]
|
||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||
if output_buffer is None:
|
||||
run_up_state.output_frames = sample
|
||||
return
|
||||
output_slice = output_buffer[:, :, output_offset[0]:output_offset[0] + sample.shape[2]]
|
||||
t = output_slice.shape[2]
|
||||
output_slice.copy_(sample[:, :, :t])
|
||||
output_offset[0] += t
|
||||
if t < sample.shape[2]:
|
||||
run_up_state.output_frames = sample[:, :, t:]
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if ended:
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
sample = run_up_state.checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=run_up_state.scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
sample = run_up_state.checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||
num_chunks = (total_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size
|
||||
|
||||
if num_chunks == 1:
|
||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||
next_sample_ref = [sample]
|
||||
del sample
|
||||
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
#Just let this run_up unconditionally regardless of, its ok because either a lower layer
|
||||
#chunker or output frame stash will do the work anyway. so unchanged.
|
||||
self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
|
||||
return
|
||||
else:
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
while len(samples):
|
||||
if output_buffer is None or output_offset[0] == output_buffer.shape[2]:
|
||||
run_up_state.pending_samples.append((idx + 1, samples, ended))
|
||||
return
|
||||
self.run_up(idx + 1, [samples.pop(0)], ended and len(samples) == 1, run_up_state, output_buffer, output_offset)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@ -591,6 +616,7 @@ class Decoder(nn.Module):
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
output_shape = self.decode_output_shape(sample.shape)
|
||||
|
||||
mark_conv3d_ended(self.conv_in)
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
@ -630,29 +656,89 @@ class Decoder(nn.Module):
|
||||
)
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
output_offset = [0]
|
||||
|
||||
run_up_state = RunUpState(
|
||||
timestep_shift_scale=timestep_shift_scale,
|
||||
scaled_timestep=scaled_timestep,
|
||||
checkpoint_fn=checkpoint_fn,
|
||||
max_chunk_size=get_max_chunk_size(sample.device),
|
||||
output_shape=output_shape,
|
||||
output_dtype=sample.dtype,
|
||||
)
|
||||
self.temporal_cache_state[threading.get_ident()] = run_up_state
|
||||
|
||||
self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset)
|
||||
|
||||
return output_buffer
|
||||
|
||||
def forward_start(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
):
|
||||
try:
|
||||
return self.forward_orig(sample, timestep=timestep, output_buffer=None)
|
||||
except Exception:
|
||||
clear_temporal_cache_state(self)
|
||||
raise
|
||||
|
||||
def forward_resume(self, output_t: int):
|
||||
tid = threading.get_ident()
|
||||
run_up_state = self.temporal_cache_state.get(tid, None)
|
||||
if run_up_state is None:
|
||||
return None
|
||||
|
||||
output_shape = list(run_up_state.output_shape)
|
||||
output_shape[2] = output_t
|
||||
output_buffer = torch.empty(
|
||||
output_shape,
|
||||
dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
|
||||
try:
|
||||
if run_up_state.output_frames is not None:
|
||||
output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
|
||||
t = output_slice.shape[2]
|
||||
output_slice.copy_(run_up_state.output_frames[:, :, :t])
|
||||
output_offset[0] += t
|
||||
run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]
|
||||
|
||||
pending_samples = run_up_state.pending_samples
|
||||
run_up_state.pending_samples = []
|
||||
while len(pending_samples):
|
||||
idx, samples, ended = pending_samples.pop(0)
|
||||
while len(samples):
|
||||
if output_offset[0] == output_buffer.shape[2]:
|
||||
pending_samples = [(idx, samples, ended)] + pending_samples
|
||||
run_up_state.pending_samples.extend(pending_samples)
|
||||
return output_buffer
|
||||
sample1 = samples.pop(0)
|
||||
self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)
|
||||
|
||||
if run_up_state.output_frames is None and not run_up_state.pending_samples:
|
||||
clear_temporal_cache_state(self)
|
||||
return output_buffer[:, :, :output_offset[0]]
|
||||
except Exception:
|
||||
clear_temporal_cache_state(self)
|
||||
raise
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
output_buffer: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if output_buffer is None:
|
||||
output_buffer = torch.empty(
|
||||
self.decode_output_shape(sample.shape),
|
||||
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
|
||||
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
|
||||
return output_buffer
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
|
||||
finally:
|
||||
for _, module in self.named_modules():
|
||||
#ComfyUI doesn't thread this kind of stuff today, but just incase
|
||||
#we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
clear_temporal_cache_state(self)
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
@ -1302,6 +1388,15 @@ class VideoVAE(nn.Module):
|
||||
def decode_output_shape(self, input_shape):
|
||||
return self.decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode_start(self, x):
|
||||
clear_temporal_cache_state(self.decoder)
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder.forward_start(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
def decode_chunk(self, output_t: int):
|
||||
return self.decoder.forward_resume(output_t)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
|
||||
11
comfy/sd.py
11
comfy/sd.py
@ -1014,6 +1014,17 @@ class VAE:
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_output_shape(self, samples_shape):
|
||||
self.throw_exception_if_invalid()
|
||||
if hasattr(self.first_stage_model, "decode_output_shape"):
|
||||
return self.first_stage_model.decode_output_shape(samples_shape)
|
||||
raise RuntimeError("This VAE does not expose decode output shape information.")
|
||||
|
||||
def decode_stream_start(self, samples_in):
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
from comfy_api.latest._input import (
|
||||
ImageInput,
|
||||
AudioInput,
|
||||
ImageStreamInput,
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
@ -14,6 +15,7 @@ from comfy_api.latest._input import (
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"ImageStreamInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
|
||||
6
comfy_api/input/image_stream_types.py
Normal file
6
comfy_api/input/image_stream_types.py
Normal file
@ -0,0 +1,6 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input.image_stream_types import ImageStreamInput
|
||||
|
||||
__all__ = [
|
||||
"ImageStreamInput",
|
||||
]
|
||||
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input import ImageInput, AudioInput, ImageStreamInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
|
||||
from . import _io_public as io
|
||||
@ -131,6 +131,7 @@ class ComfyExtension(ABC):
|
||||
class Input:
|
||||
Image = ImageInput
|
||||
Audio = AudioInput
|
||||
ImageStream = ImageStreamInput
|
||||
Mask = MaskInput
|
||||
Latent = LatentInput
|
||||
Video = VideoInput
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .image_stream_types import ImageStreamInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"ImageStreamInput",
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
|
||||
67
comfy_api/latest/_input/image_stream_types.py
Normal file
67
comfy_api/latest/_input/image_stream_types.py
Normal file
@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
|
||||
from comfy_execution.utils import CurrentNodeContext, get_executing_context
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from .basic_types import ImageInput
|
||||
|
||||
|
||||
class ImageStreamInput(ABC):
|
||||
"""Abstract base class for pull-based image stream inputs.
|
||||
|
||||
Consumers request up to ``max_frames`` frames at a time. Producers must not
|
||||
over-return; a batch with fewer than ``max_frames`` frames signals EOF.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
#Subclasses must call this init for future core ComfyUI change compatibilty
|
||||
self._ctx = get_executing_context()
|
||||
|
||||
def reset(self) -> None:
|
||||
#This API is final. Subclasses must NOT override this for future core ComfyUI
|
||||
#change compatability. Override do_reset instead.
|
||||
with (nullcontext() if self._ctx is None else
|
||||
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
|
||||
self.do_reset()
|
||||
|
||||
if self._ctx is not None:
|
||||
get_progress_state().finish_progress(self._ctx.node_id)
|
||||
|
||||
def pull(self, max_frames: int) -> ImageInput:
|
||||
#This API is final. Subclasses must NOT override this for future core ComfyUI
|
||||
#change compatability. Override do_pull instead.
|
||||
with (nullcontext() if self._ctx is None else
|
||||
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
|
||||
result = self.do_pull(max_frames)
|
||||
|
||||
if self._ctx is not None:
|
||||
registry = get_progress_state()
|
||||
entry = registry.nodes.get(self._ctx.node_id)
|
||||
if (int(result.shape[0]) < max_frames or
|
||||
(entry is not None and entry["max"] > 0 and entry["value"] >= entry["max"])):
|
||||
registry.finish_progress(self._ctx.node_id)
|
||||
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""Return the stream frame dimensions as ``(width, height)``."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_reset(self) -> None:
|
||||
"""Reset the stream so the next pull starts from frame 0."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_pull(self, max_frames: int) -> ImageInput:
|
||||
"""Return up to ``max_frames`` images.
|
||||
|
||||
The returned tensor uses the normal ``IMAGE`` batch shape. A short
|
||||
return, where the batch dimension is less than ``max_frames``, is the
|
||||
EOF signal. Sources are expected to short-return at least once before
|
||||
exhaustion, including returning an empty batch.
|
||||
"""
|
||||
pass
|
||||
@ -386,6 +386,7 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
self._frame_counter = 0
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
@ -394,14 +395,13 @@ class VideoFromComponents(VideoInput):
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def save_to(
|
||||
def save_start(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
@ -413,7 +413,12 @@ class VideoFromComponents(VideoInput):
|
||||
# BytesIO has no file extension, so av.open can't infer the format.
|
||||
# Default to mp4 since that's the only supported format anyway.
|
||||
extra_kwargs["format"] = "mp4"
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
|
||||
|
||||
width, height = self.get_dimensions()
|
||||
|
||||
output = av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs)
|
||||
if True:
|
||||
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
@ -422,8 +427,8 @@ class VideoFromComponents(VideoInput):
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.width = width
|
||||
video_stream.height = height
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
@ -432,23 +437,33 @@ class VideoFromComponents(VideoInput):
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
waveform = waveform[0]
|
||||
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
|
||||
|
||||
self._frame_counter = 0
|
||||
return output, video_stream, audio_stream, audio_sample_rate, frame_rate
|
||||
|
||||
def save_add(self, output, video_stream, images) -> None:
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
for frame in images:
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
self._frame_counter += 1
|
||||
|
||||
def save_finalize(self, output, video_stream, audio_stream, audio_sample_rate, frame_rate) -> None:
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[0]
|
||||
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
|
||||
waveform = waveform[:, :math.ceil((audio_sample_rate / frame_rate) * self._frame_counter)]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
@ -457,6 +472,29 @@ class VideoFromComponents(VideoInput):
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
|
||||
output.close()
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
output, video_stream, audio_stream, audio_sample_rate, frame_rate = self.save_start(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata,
|
||||
)
|
||||
try:
|
||||
self.save_add(output, video_stream, self.__components.images)
|
||||
self.save_finalize(output, video_stream, audio_stream, audio_sample_rate, frame_rate)
|
||||
except Exception:
|
||||
output.close()
|
||||
raise
|
||||
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
|
||||
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from comfy.samplers import CFGGuider, Sampler
|
||||
from comfy.sd import CLIP, VAE
|
||||
from comfy.sd import StyleModel as StyleModel_
|
||||
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.input import ImageStreamInput, VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
@ -420,6 +420,12 @@ class Image(ComfyTypeIO):
|
||||
Type = torch.Tensor
|
||||
|
||||
|
||||
@comfytype(io_type="IMAGE_STREAM")
|
||||
class ImageStream(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = ImageStreamInput
|
||||
|
||||
|
||||
@comfytype(io_type="WAN_CAMERA_EMBEDDING")
|
||||
class WanCameraEmbedding(ComfyTypeIO):
|
||||
Type = torch.Tensor
|
||||
@ -2203,6 +2209,7 @@ __all__ = [
|
||||
"Combo",
|
||||
"MultiCombo",
|
||||
"Image",
|
||||
"ImageStream",
|
||||
"WanCameraEmbedding",
|
||||
"Webcam",
|
||||
"Mask",
|
||||
|
||||
320
comfy_extras/nodes_image_stream.py
Normal file
320
comfy_extras/nodes_image_stream.py
Normal file
@ -0,0 +1,320 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, Input, io, ui
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class FrameProgressTracker:
|
||||
def __init__(self):
|
||||
self._last_reported=None
|
||||
|
||||
def emit(self, frames_processed):
|
||||
if frames_processed == self._last_reported:
|
||||
return
|
||||
|
||||
current = get_executing_context()
|
||||
server = getattr(PromptServer, "instance", None)
|
||||
if current is None or server is None or server.client_id is None:
|
||||
return
|
||||
|
||||
server.send_progress_text(
|
||||
f"processed {frames_processed} frames",
|
||||
current.node_id,
|
||||
server.client_id,
|
||||
)
|
||||
self._last_reported = frames_processed
|
||||
|
||||
|
||||
def drain_image_stream(stream, chunk_size, progress):
|
||||
stream.reset()
|
||||
frames_processed = 0
|
||||
|
||||
while True:
|
||||
progress.emit(frames_processed)
|
||||
chunk = stream.pull(chunk_size)
|
||||
frames_processed += int(chunk.shape[0])
|
||||
if chunk.shape[0] < chunk_size:
|
||||
progress.emit(frames_processed)
|
||||
return frames_processed
|
||||
|
||||
class TensorImageStream(Input.ImageStream):
|
||||
"""Simple IMAGE_STREAM backed by a materialized IMAGE batch tensor."""
|
||||
|
||||
def __init__(self, images: Input.Image):
|
||||
super().__init__()
|
||||
self._images = images
|
||||
self._index = 0
|
||||
self._total_frames = int(images.shape[0])
|
||||
self._progress_started = False
|
||||
|
||||
def _update_progress(self, value: int) -> None:
|
||||
current = get_executing_context()
|
||||
if current is None:
|
||||
return
|
||||
get_progress_state().update_progress(
|
||||
current.node_id,
|
||||
value=float(value),
|
||||
max_value=float(max(self._total_frames, 1)),
|
||||
)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
return self._images.shape[2], self._images.shape[1]
|
||||
|
||||
def do_reset(self) -> None:
|
||||
self._index = 0
|
||||
self._progress_started = False
|
||||
|
||||
def do_pull(self, max_frames: int) -> Input.Image:
|
||||
if not self._progress_started:
|
||||
self._update_progress(0)
|
||||
self._progress_started = True
|
||||
|
||||
start = self._index
|
||||
end = min(start + max_frames, self._images.shape[0])
|
||||
self._index = end
|
||||
chunk = self._images[start:end].clone()
|
||||
self._update_progress(end)
|
||||
return chunk
|
||||
|
||||
|
||||
class PreviewingImageStream(Input.ImageStream):
|
||||
def __init__(self, stream: Input.ImageStream):
|
||||
super().__init__()
|
||||
self._stream = stream
|
||||
|
||||
def _emit_preview(self, chunk: Input.Image) -> None:
|
||||
if int(chunk.shape[0]) == 0:
|
||||
return
|
||||
|
||||
current = get_executing_context()
|
||||
if current is None:
|
||||
return
|
||||
|
||||
server = getattr(PromptServer, "instance", None)
|
||||
if server is None or server.client_id is None:
|
||||
return
|
||||
|
||||
preview_output = ui.PreviewImage(chunk[-1:]).as_dict()
|
||||
server.send_sync(
|
||||
"executed",
|
||||
{
|
||||
"node": current.node_id,
|
||||
"display_node": current.node_id,
|
||||
"output": preview_output,
|
||||
"prompt_id": current.prompt_id,
|
||||
},
|
||||
server.client_id,
|
||||
)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
return self._stream.get_dimensions()
|
||||
|
||||
def do_reset(self) -> None:
|
||||
self._stream.reset()
|
||||
|
||||
def do_pull(self, max_frames: int) -> Input.Image:
|
||||
chunk = self._stream.pull(max_frames)
|
||||
self._emit_preview(chunk)
|
||||
return chunk
|
||||
|
||||
|
||||
class VAEDecodedImageStream(Input.ImageStream):
|
||||
def __init__(self, vae, latent: Input.Latent):
|
||||
super().__init__()
|
||||
self._vae = vae
|
||||
self._latent = latent
|
||||
vae.throw_exception_if_invalid()
|
||||
if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
|
||||
raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
|
||||
if latent.ndim != 5:
|
||||
raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
|
||||
if latent.shape[0] != 1:
|
||||
raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
|
||||
output_shape = vae.decode_output_shape(latent.shape)
|
||||
self._channels = int(output_shape[1])
|
||||
self._width = int(output_shape[4])
|
||||
self._height = int(output_shape[3])
|
||||
self._total_frames = int(output_shape[0] * output_shape[2])
|
||||
self._frames_emitted = 0
|
||||
|
||||
def _update_progress(self, value: int) -> None:
|
||||
current = get_executing_context()
|
||||
if current is None:
|
||||
return
|
||||
get_progress_state().update_progress(
|
||||
current.node_id,
|
||||
value=float(value),
|
||||
max_value=float(max(self._total_frames, 1)),
|
||||
)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
return self._width, self._height
|
||||
|
||||
def do_reset(self) -> None:
|
||||
self._frames_emitted = 0
|
||||
self._update_progress(0)
|
||||
self._vae.decode_stream_start(self._latent)
|
||||
|
||||
def do_pull(self, max_frames: int) -> Input.Image:
|
||||
chunk = self._vae.first_stage_model.decode_chunk(max_frames)
|
||||
if chunk is None:
|
||||
return torch.empty(
|
||||
(0, self._height, self._width, self._channels),
|
||||
device=self._vae.output_device,
|
||||
dtype=self._vae.vae_output_dtype(),
|
||||
)
|
||||
|
||||
chunk = chunk.to(device=self._vae.output_device, dtype=self._vae.vae_output_dtype())
|
||||
chunk = self._vae.process_output(chunk).movedim(1, -1)
|
||||
chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))
|
||||
self._frames_emitted += int(chunk.shape[0])
|
||||
self._update_progress(self._frames_emitted)
|
||||
return chunk
|
||||
|
||||
|
||||
class ImageBatchToStream(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageBatchToStream",
|
||||
display_name="Image Batch To Stream",
|
||||
category="image/stream",
|
||||
search_aliases=["image to stream", "batch to stream", "frames to stream"],
|
||||
description="Wraps a batched IMAGE tensor as a pull-based IMAGE_STREAM.",
|
||||
inputs=[
|
||||
io.Image.Input("image", tooltip="A batched IMAGE tensor in BHWC format."),
|
||||
],
|
||||
outputs=[
|
||||
io.ImageStream.Output(display_name="stream"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: Input.Image) -> io.NodeOutput:
|
||||
return io.NodeOutput(TensorImageStream(image))
|
||||
|
||||
|
||||
class ImageStreamToBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageStreamToBatch",
|
||||
display_name="Image Stream To Batch",
|
||||
category="image/stream",
|
||||
search_aliases=["stream to image", "stream to batch", "collect stream"],
|
||||
description="Materializes an IMAGE_STREAM back into a batched IMAGE tensor.",
|
||||
inputs=[
|
||||
io.ImageStream.Input("stream", tooltip="A pull-based IMAGE_STREAM."),
|
||||
io.Int.Input("batch_size", default=4096, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, stream: Input.ImageStream, batch_size: int) -> io.NodeOutput:
|
||||
chunks: list[Input.Image] = []
|
||||
stream.reset()
|
||||
|
||||
while True:
|
||||
chunk = stream.pull(batch_size)
|
||||
|
||||
chunks.append(chunk)
|
||||
if chunk.shape[0] < batch_size:
|
||||
break
|
||||
|
||||
return io.NodeOutput(torch.cat(chunks, dim=0))
|
||||
|
||||
|
||||
class VAEDecodeStream(io.ComfyNode):
    """Node that produces an IMAGE_STREAM which decodes a latent chunk-by-chunk through a VAE."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VAEDecodeStream",
            display_name="VAE Decode Stream",
            category="image/stream",
            search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
            description="Decodes a latent into an IMAGE_STREAM.",
            inputs=[
                io.Latent.Input("samples", tooltip="The LTX latent to decode."),
                io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
            ],
            outputs=[io.ImageStream.Output(display_name="stream")],
        )

    @classmethod
    def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
        latent_tensor = samples["samples"]
        # Nested latents hold one tensor per sequence; stream the first entry.
        if latent_tensor.is_nested:
            latent_tensor = latent_tensor.unbind()[0]
        return io.NodeOutput(VAEDecodedImageStream(vae, latent_tensor))
||||
|
||||
class PreviewImageStream(io.ComfyNode):
    """Pass-through node that previews a frame from each chunk as the stream is pulled."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="PreviewImageStream",
            display_name="Preview Image Stream",
            category="image/stream",
            search_aliases=["stream preview", "preview frames", "preview image stream"],
            description="Passes an IMAGE_STREAM through while previewing the last frame from each pulled chunk.",
            has_intermediate_output=True,
            inputs=[io.ImageStream.Input("stream", tooltip="The image stream to preview inline.")],
            outputs=[io.ImageStream.Output(display_name="passthrough")],
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream) -> io.NodeOutput:
        # Wrap rather than drain: previews are emitted lazily as downstream pulls.
        return io.NodeOutput(PreviewingImageStream(stream))
||||
|
||||
class StreamSink(io.ComfyNode):
    """Output node that fully consumes an IMAGE_STREAM, discarding the frames."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="StreamSink",
            search_aliases=["consume stream", "drain stream", "image stream sink"],
            display_name="Stream Sink",
            category="image/stream",
            description="Consumes an IMAGE_STREAM by pulling it to EOF.",
            inputs=[
                io.ImageStream.Input("stream", tooltip="The image stream to consume."),
                io.Int.Input("chunk_size", default=8, min=1, max=4096),
            ],
            outputs=[],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, stream: Input.ImageStream, chunk_size: int) -> io.NodeOutput:
        # Pulling to EOF triggers whatever side effects upstream stream nodes carry.
        tracker = FrameProgressTracker()
        drain_image_stream(stream, chunk_size, progress=tracker)
        return io.NodeOutput()
||||
|
||||
class ImageStreamExtension(ComfyExtension):
    """Extension that registers the IMAGE_STREAM node pack."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        nodes: list[type[io.ComfyNode]] = [
            ImageBatchToStream,
            ImageStreamToBatch,
            VAEDecodeStream,
            PreviewImageStream,
            StreamSink,
        ]
        return nodes
||||
|
||||
async def comfy_entrypoint() -> ImageStreamExtension:
    """Entry point ComfyUI awaits to obtain this extension instance."""
    extension = ImageStreamExtension()
    return extension
|
||||
@ -5,8 +5,7 @@ import torch
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.model_management
|
||||
from comfy_api.latest import ComfyExtension, Input, io
|
||||
|
||||
try:
|
||||
from spandrel_extra_arches import EXTRA_REGISTRY
|
||||
@ -47,9 +46,29 @@ class UpscaleModelLoader(io.ComfyNode):
|
||||
load_model = execute # TODO: remove
|
||||
|
||||
|
||||
class UpscaledImageStream(Input.ImageStream):
    """IMAGE_STREAM adapter that runs an upscale model on each pulled chunk."""

    def __init__(self, upscale_model, stream: Input.ImageStream):
        super().__init__()
        self._upscale_model = upscale_model
        self._stream = stream

    def get_dimensions(self) -> tuple[int, int]:
        # Output dimensions are the source dimensions multiplied by the model's scale factor.
        src_w, src_h = self._stream.get_dimensions()
        factor = self._upscale_model.scale
        return int(src_w * factor), int(src_h * factor)

    def do_reset(self) -> None:
        self._stream.reset()

    def do_pull(self, max_frames: int) -> Input.Image:
        frames = self._stream.pull(max_frames)
        # Delegate the actual model invocation to the node's shared batch helper.
        return ImageUpscaleWithModel.upscale_batch(self._upscale_model, frames)
||||
|
||||
class ImageUpscaleWithModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
image_template = io.MatchType.Template("image_type", allowed_types=[io.Image, io.ImageStream])
|
||||
return io.Schema(
|
||||
node_id="ImageUpscaleWithModel",
|
||||
display_name="Upscale Image (using Model)",
|
||||
@ -57,15 +76,18 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
|
||||
inputs=[
|
||||
io.UpscaleModel.Input("upscale_model"),
|
||||
io.Image.Input("image"),
|
||||
io.MatchType.Input("image", template=image_template),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
io.MatchType.Output(template=image_template, display_name="image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, upscale_model, image) -> io.NodeOutput:
|
||||
def upscale_batch(cls, upscale_model, image: torch.Tensor) -> torch.Tensor:
|
||||
if image.shape[0] == 0:
|
||||
return image.clone()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
|
||||
memory_required = model_management.module_size(upscale_model.model)
|
||||
@ -79,7 +101,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
tile = 512
|
||||
overlap = 32
|
||||
|
||||
output_device = comfy.model_management.intermediate_device()
|
||||
output_device = model_management.intermediate_device()
|
||||
|
||||
oom = True
|
||||
try:
|
||||
@ -97,8 +119,14 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||
return io.NodeOutput(s)
|
||||
return torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(model_management.intermediate_dtype())
|
||||
|
||||
@classmethod
|
||||
def execute(cls, upscale_model, image) -> io.NodeOutput:
|
||||
if isinstance(image, torch.Tensor):
|
||||
return io.NodeOutput(cls.upscale_batch(upscale_model, image))
|
||||
|
||||
return io.NodeOutput(UpscaledImageStream(upscale_model, image))
|
||||
|
||||
upscale = execute # TODO: remove
|
||||
|
||||
|
||||
@ -5,11 +5,162 @@ import av
|
||||
import torch
|
||||
import folder_paths
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_extras.nodes_image_stream import FrameProgressTracker, drain_image_stream
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class SavedVideoStream(Input.ImageStream):
    """IMAGE_STREAM wrapper that writes every pulled chunk to a video file as a side effect.

    The stream delegates frame production to `stream` and uses `saver` (a video
    object exposing save_start/save_add/save_finalize) to incrementally encode the
    frames. On reset, any partially written file is discarded and a fresh output
    path is obtained from `output_factory`. When the underlying stream signals EOF
    (a short chunk), the file is finalized and, optionally, a preview event is
    pushed to the client.
    """

    def __init__(
        self,
        stream: Input.ImageStream,
        saver,
        output_factory: Callable[[], tuple[str, ui.PreviewVideo | None]],
        format: str,
        codec,
        metadata: Optional[dict],
        emit_preview_on_finalize: bool = True,
        preview_node_id: Optional[str] = None,
        preview_display_node_id: Optional[str] = None,
    ):
        super().__init__()
        self._stream = stream
        self._saver = saver
        self._output_factory = output_factory
        # Current output file path; None when no save is in progress.
        self._path: str | None = None
        self._format = Types.VideoContainer(format)
        self._codec = Types.VideoCodec(codec)
        self._metadata = metadata
        self._preview_ui: ui.PreviewVideo | None = None
        self._emit_preview_on_finalize = emit_preview_on_finalize
        self._preview_node_id = preview_node_id
        self._preview_display_node_id = preview_display_node_id
        # Opaque state returned by saver.save_start; index 0 is a closeable
        # container (see _discard_partial_output). None when not saving.
        self._save_state = None

    def _emit_preview(self) -> None:
        """Send an 'executed' preview event for the finished file, if possible."""
        if not self._emit_preview_on_finalize or self._preview_ui is None:
            return

        # Only emit when running inside an execution context (we need prompt/node ids).
        current = get_executing_context()
        if current is None:
            return

        server = getattr(PromptServer, "instance", None)
        if server is None or server.client_id is None:
            return

        server.send_sync(
            "executed",
            {
                "node": self._preview_node_id or current.node_id,
                "display_node": self._preview_display_node_id or self._preview_node_id or current.node_id,
                "output": self._preview_ui.as_dict(),
                "prompt_id": current.prompt_id,
            },
            server.client_id,
        )

    def _discard_partial_output(self) -> None:
        """Abort any in-progress save: close the encoder and delete the partial file."""
        if self._save_state is not None:
            self._save_state[0].close()
            self._save_state = None
        if self._path is not None and os.path.exists(self._path):
            os.remove(self._path)
        self._path = None
        self._preview_ui = None

    def get_preview_ui(self) -> ui.PreviewVideo | None:
        # Non-None only after a successful reset; cleared again on discard.
        return self._preview_ui

    def get_dimensions(self) -> tuple[int, int]:
        return self._stream.get_dimensions()

    def do_reset(self) -> None:
        self._discard_partial_output()
        self._stream.reset()
        self._path, self._preview_ui = self._output_factory()
        assert self._path is not None
        # Touch the file so it exists even before the first chunk is written.
        open(self._path, "ab").close()
        self._save_state = self._saver.save_start(
            self._path,
            format=self._format,
            codec=self._codec,
            metadata=self._metadata,
        )

    def do_pull(self, max_frames: int) -> Input.Image:
        # do_reset must have been called first to open the save state.
        assert self._save_state is not None
        chunk = self._stream.pull(max_frames)
        self._saver.save_add(self._save_state[0], self._save_state[1], chunk)
        if chunk.shape[0] < max_frames:
            # Short chunk == EOF: finalize the file and announce the preview.
            self._saver.save_finalize(*self._save_state)
            self._save_state = None
            self._emit_preview()
            self._path = None
        return chunk
||||
def _build_saved_stream(
    hidden,
    stream: Input.ImageStream,
    audio: Optional[Input.Audio],
    fps: float,
    filename_prefix,
    format: str,
    codec,
    emit_preview: bool = True,
) -> SavedVideoStream:
    """Build a SavedVideoStream that writes `stream` to an output video file.

    Collects workflow metadata from the hidden inputs (unless disabled by CLI
    args), resolves the preview node ids, and supplies an output_factory that
    allocates a fresh numbered output path on each stream reset.
    """
    width, height = stream.get_dimensions()

    saved_metadata = None
    if not args.disable_metadata:
        metadata = {}
        if hidden.extra_pnginfo is not None:
            metadata.update(hidden.extra_pnginfo)
        if hidden.prompt is not None:
            metadata["prompt"] = hidden.prompt
        if len(metadata) > 0:
            saved_metadata = metadata

    preview_node_id = hidden.unique_id
    preview_display_node_id = preview_node_id
    if hidden.dynprompt is not None and preview_node_id is not None:
        preview_display_node_id = hidden.dynprompt.get_display_node_id(preview_node_id)

    def output_factory() -> tuple[str, ui.PreviewVideo]:
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix,
            folder_paths.get_output_directory(),
            width,
            height,
        )
        # Fix: use the resolved filename stem from get_save_image_path instead of
        # a literal placeholder, matching the standard ComfyUI save naming scheme.
        file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
        return (
            os.path.join(full_output_folder, file),
            ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]),
        )

    return SavedVideoStream(
        stream,
        # An empty-frame video object serves purely as the incremental saver.
        InputImpl.VideoFromComponents(
            Types.VideoComponents(
                images=torch.zeros((0, height, width, 3)),
                audio=audio,
                frame_rate=Fraction(fps),
            )
        ),
        output_factory,
        format,
        codec,
        saved_metadata,
        emit_preview,
        preview_node_id,
        preview_display_node_id,
    )
||||
class SaveWEBM(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -114,6 +265,93 @@ class SaveVideo(io.ComfyNode):
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||
|
||||
|
||||
class SavePassthroughVideoStream(io.ComfyNode):
    """Node that wraps an IMAGE_STREAM so frames are saved to a video file as they flow through."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SavePassthroughVideoStream",
            search_aliases=["stream to video", "save image stream", "export video stream", "passthrough video stream"],
            display_name="Save+Passthrough Video Stream",
            category="image/video",
            essentials_category="Basics",
            description="Saves frames as they pass through the input image stream.",
            has_intermediate_output=True,
            inputs=[
                io.ImageStream.Input("stream", tooltip="The image stream to save."),
                io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
                io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
                io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
            ],
            outputs=[io.ImageStream.Output(display_name="passthrough")],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
        )

    @classmethod
    def execute(
        cls,
        stream: Input.ImageStream,
        fps: float,
        filename_prefix,
        format: str,
        codec,
        audio: Optional[Input.Audio] = None,
    ) -> io.NodeOutput:
        # The returned stream is lazy: saving happens as downstream pulls frames.
        saving_stream = _build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec)
        return io.NodeOutput(saving_stream)
||||
class SaveVideoStream(io.ComfyNode):
    """Output node that eagerly drains an IMAGE_STREAM into a saved video file."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveVideoStream",
            search_aliases=["save image stream", "export video stream", "stream to video"],
            display_name="Save Video Stream",
            category="image/video",
            essentials_category="Basics",
            description="Saves an image stream by draining it directly to EOF.",
            is_output_node=True,
            inputs=[
                io.ImageStream.Input("stream", tooltip="The image stream to save."),
                io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
                io.Int.Input("chunk_size", default=8, min=1, max=4096),
                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
                io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
                io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
        )

    @classmethod
    def execute(
        cls,
        stream: Input.ImageStream,
        fps: float,
        chunk_size: int,
        filename_prefix,
        format: str,
        codec,
        audio: Optional[Input.Audio] = None,
    ) -> io.NodeOutput:
        # Previews are suppressed during the save; the final UI result is
        # returned from the node output instead.
        saving_stream = _build_saved_stream(
            cls.hidden,
            stream,
            audio,
            fps,
            filename_prefix,
            format,
            codec,
            emit_preview=False,
        )
        drain_image_stream(saving_stream, chunk_size, progress=FrameProgressTracker())
        return io.NodeOutput(ui=saving_stream.get_preview_ui())
|
||||
|
||||
class CreateVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -262,6 +500,8 @@ class VideoExtension(ComfyExtension):
|
||||
return [
|
||||
SaveWEBM,
|
||||
SaveVideo,
|
||||
SavePassthroughVideoStream,
|
||||
SaveVideoStream,
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user