mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
nodes_video: implement stream saver nodes
This commit is contained in:
parent
05b5240d70
commit
3d21d2afb6
@ -5,11 +5,162 @@ import av
|
||||
import torch
|
||||
import folder_paths
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_extras.nodes_image_stream import FrameProgressTracker, drain_image_stream
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class SavedVideoStream(Input.ImageStream):
|
||||
def __init__(
|
||||
self,
|
||||
stream: Input.ImageStream,
|
||||
saver,
|
||||
output_factory: Callable[[], tuple[str, ui.PreviewVideo | None]],
|
||||
format: str,
|
||||
codec,
|
||||
metadata: Optional[dict],
|
||||
emit_preview_on_finalize: bool = True,
|
||||
preview_node_id: Optional[str] = None,
|
||||
preview_display_node_id: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._stream = stream
|
||||
self._saver = saver
|
||||
self._output_factory = output_factory
|
||||
self._path: str | None = None
|
||||
self._format = Types.VideoContainer(format)
|
||||
self._codec = Types.VideoCodec(codec)
|
||||
self._metadata = metadata
|
||||
self._preview_ui: ui.PreviewVideo | None = None
|
||||
self._emit_preview_on_finalize = emit_preview_on_finalize
|
||||
self._preview_node_id = preview_node_id
|
||||
self._preview_display_node_id = preview_display_node_id
|
||||
self._save_state = None
|
||||
|
||||
def _emit_preview(self) -> None:
|
||||
if not self._emit_preview_on_finalize or self._preview_ui is None:
|
||||
return
|
||||
|
||||
current = get_executing_context()
|
||||
if current is None:
|
||||
return
|
||||
|
||||
server = getattr(PromptServer, "instance", None)
|
||||
if server is None or server.client_id is None:
|
||||
return
|
||||
|
||||
server.send_sync(
|
||||
"executed",
|
||||
{
|
||||
"node": self._preview_node_id or current.node_id,
|
||||
"display_node": self._preview_display_node_id or self._preview_node_id or current.node_id,
|
||||
"output": self._preview_ui.as_dict(),
|
||||
"prompt_id": current.prompt_id,
|
||||
},
|
||||
server.client_id,
|
||||
)
|
||||
|
||||
def _discard_partial_output(self) -> None:
|
||||
if self._save_state is not None:
|
||||
self._save_state[0].close()
|
||||
self._save_state = None
|
||||
if self._path is not None and os.path.exists(self._path):
|
||||
os.remove(self._path)
|
||||
self._path = None
|
||||
self._preview_ui = None
|
||||
|
||||
def get_preview_ui(self) -> ui.PreviewVideo | None:
|
||||
return self._preview_ui
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
return self._stream.get_dimensions()
|
||||
|
||||
def do_reset(self) -> None:
|
||||
self._discard_partial_output()
|
||||
self._stream.reset()
|
||||
self._path, self._preview_ui = self._output_factory()
|
||||
assert self._path is not None
|
||||
open(self._path, "ab").close()
|
||||
self._save_state = self._saver.save_start(
|
||||
self._path,
|
||||
format=self._format,
|
||||
codec=self._codec,
|
||||
metadata=self._metadata,
|
||||
)
|
||||
|
||||
def do_pull(self, max_frames: int) -> Input.Image:
|
||||
assert self._save_state is not None
|
||||
chunk = self._stream.pull(max_frames)
|
||||
self._saver.save_add(self._save_state[0], self._save_state[1], chunk)
|
||||
if chunk.shape[0] < max_frames:
|
||||
self._saver.save_finalize(*self._save_state)
|
||||
self._save_state = None
|
||||
self._emit_preview()
|
||||
self._path = None
|
||||
return chunk
|
||||
|
||||
|
||||
def _build_saved_stream(
|
||||
hidden,
|
||||
stream: Input.ImageStream,
|
||||
audio: Optional[Input.Audio],
|
||||
fps: float,
|
||||
filename_prefix,
|
||||
format: str,
|
||||
codec,
|
||||
emit_preview: bool = True,
|
||||
) -> SavedVideoStream:
|
||||
width, height = stream.get_dimensions()
|
||||
saved_metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = {}
|
||||
if hidden.extra_pnginfo is not None:
|
||||
metadata.update(hidden.extra_pnginfo)
|
||||
if hidden.prompt is not None:
|
||||
metadata["prompt"] = hidden.prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
|
||||
preview_node_id = hidden.unique_id
|
||||
preview_display_node_id = preview_node_id
|
||||
if hidden.dynprompt is not None and preview_node_id is not None:
|
||||
preview_display_node_id = hidden.dynprompt.get_display_node_id(preview_node_id)
|
||||
|
||||
def output_factory() -> tuple[str, ui.PreviewVideo]:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
folder_paths.get_output_directory(),
|
||||
width,
|
||||
height,
|
||||
)
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
return (
|
||||
os.path.join(full_output_folder, file),
|
||||
ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]),
|
||||
)
|
||||
|
||||
return SavedVideoStream(
|
||||
stream,
|
||||
InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(
|
||||
images=torch.zeros((0, height, width, 3)),
|
||||
audio=audio,
|
||||
frame_rate=Fraction(fps),
|
||||
)
|
||||
),
|
||||
output_factory,
|
||||
format,
|
||||
codec,
|
||||
saved_metadata,
|
||||
emit_preview,
|
||||
preview_node_id,
|
||||
preview_display_node_id,
|
||||
)
|
||||
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -114,6 +265,93 @@ class SaveVideo(io.ComfyNode):
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||
|
||||
|
||||
class SavePassthroughVideoStream(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SavePassthroughVideoStream",
|
||||
search_aliases=["stream to video", "save image stream", "export video stream", "passthrough video stream"],
|
||||
display_name="Save+Passthrough Video Stream",
|
||||
category="image/video",
|
||||
essentials_category="Basics",
|
||||
description="Saves frames as they pass through the input image stream.",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
io.ImageStream.Input("stream", tooltip="The image stream to save."),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[
|
||||
io.ImageStream.Output(display_name="passthrough"),
|
||||
],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
stream: Input.ImageStream,
|
||||
audio: Optional[Input.Audio],
|
||||
fps: float,
|
||||
filename_prefix,
|
||||
format: str,
|
||||
codec,
|
||||
) -> io.NodeOutput:
|
||||
return io.NodeOutput(_build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec))
|
||||
|
||||
|
||||
class SaveVideoStream(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveVideoStream",
|
||||
search_aliases=["save image stream", "export video stream", "stream to video"],
|
||||
display_name="Save Video Stream",
|
||||
category="image/video",
|
||||
essentials_category="Basics",
|
||||
description="Saves an image stream by draining it directly to EOF.",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
io.ImageStream.Input("stream", tooltip="The image stream to save."),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.Int.Input("chunk_size", default=8, min=1, max=4096),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
stream: Input.ImageStream,
|
||||
audio: Optional[Input.Audio],
|
||||
fps: float,
|
||||
chunk_size: int,
|
||||
filename_prefix,
|
||||
format: str,
|
||||
codec,
|
||||
) -> io.NodeOutput:
|
||||
saved_stream = _build_saved_stream(
|
||||
cls.hidden,
|
||||
stream,
|
||||
audio,
|
||||
fps,
|
||||
filename_prefix,
|
||||
format,
|
||||
codec,
|
||||
emit_preview=False,
|
||||
)
|
||||
drain_image_stream(saved_stream, chunk_size, progress=FrameProgressTracker())
|
||||
return io.NodeOutput(ui=saved_stream.get_preview_ui())
|
||||
|
||||
|
||||
class CreateVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -262,6 +500,8 @@ class VideoExtension(ComfyExtension):
|
||||
return [
|
||||
SaveWEBM,
|
||||
SaveVideo,
|
||||
SavePassthroughVideoStream,
|
||||
SaveVideoStream,
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user