mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
513 lines
20 KiB
Python
513 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import av
|
|
import torch
|
|
import folder_paths
|
|
import json
|
|
from typing import Callable, Optional
|
|
from typing_extensions import override
|
|
from fractions import Fraction
|
|
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
|
|
from comfy.cli_args import args
|
|
from comfy_execution.utils import get_executing_context
|
|
from comfy_extras.nodes_image_stream import FrameProgressTracker, drain_image_stream
|
|
from server import PromptServer
|
|
|
|
|
|
class SavedVideoStream(Input.ImageStream):
    """Image stream wrapper that encodes every pulled chunk into a video file.

    Frames pulled from the wrapped ``stream`` are returned unchanged and, as a
    side effect, appended to a video file via ``saver.save_start`` /
    ``save_add`` / ``save_finalize``. Every reset asks ``output_factory`` for a
    new output path. An output that has not reached end of stream when the
    next reset happens is treated as partial and deleted.
    """

    def __init__(
        self,
        stream: Input.ImageStream,
        saver,
        output_factory: Callable[[], tuple[str, ui.PreviewVideo | None]],
        format: str,
        codec,
        metadata: Optional[dict],
        emit_preview_on_finalize: bool = True,
        preview_node_id: Optional[str] = None,
        preview_display_node_id: Optional[str] = None,
    ):
        """
        Args:
            stream: upstream image stream that supplies the frames.
            saver: object providing ``save_start``/``save_add``/``save_finalize``
                (``_build_saved_stream`` passes an ``InputImpl.VideoFromComponents``).
            output_factory: returns ``(path, preview_ui)`` for a fresh output file.
            format: container name, converted with ``Types.VideoContainer``.
            codec: codec name, converted with ``Types.VideoCodec``.
            metadata: passed through to ``saver.save_start`` (may be None).
            emit_preview_on_finalize: send an "executed" message carrying the
                preview once the file has been finalized.
            preview_node_id: node id used in that message (defaults to the
                currently executing node).
            preview_display_node_id: display node id used in that message.
        """
        super().__init__()
        self._stream = stream
        self._saver = saver
        self._output_factory = output_factory
        # Path of the file currently being written; None once finalized so a
        # later reset does not delete a completed output.
        self._path: str | None = None
        self._format = Types.VideoContainer(format)
        self._codec = Types.VideoCodec(codec)
        self._metadata = metadata
        self._preview_ui: ui.PreviewVideo | None = None
        self._emit_preview_on_finalize = emit_preview_on_finalize
        self._preview_node_id = preview_node_id
        self._preview_display_node_id = preview_display_node_id
        # Value returned by saver.save_start(); indexed [0] and [1] below and
        # splatted into save_finalize(). None when no encode is in progress.
        self._save_state = None

    def _emit_preview(self) -> None:
        """Push the preview to the connected client as an "executed" message.

        This is a no-op when preview emission is disabled, no preview exists,
        nothing is executing, or there is no server or client to send to.
        """
        if not self._emit_preview_on_finalize or self._preview_ui is None:
            return

        current = get_executing_context()
        if current is None:
            return

        server = getattr(PromptServer, "instance", None)
        if server is None or server.client_id is None:
            return

        server.send_sync(
            "executed",
            {
                "node": self._preview_node_id or current.node_id,
                "display_node": self._preview_display_node_id or self._preview_node_id or current.node_id,
                "output": self._preview_ui.as_dict(),
                "prompt_id": current.prompt_id,
            },
            server.client_id,
        )

    def _discard_partial_output(self) -> None:
        """Close any in-progress encode and delete its incomplete file.

        All state is cleared before any cleanup call runs. As a result, a
        failure while closing the partially written container cannot leave
        this object pointing at the broken container. The file is removed
        even when that close raises.
        """
        save_state, self._save_state = self._save_state, None
        path, self._path = self._path, None
        self._preview_ui = None
        try:
            if save_state is not None:
                save_state[0].close()
        finally:
            if path is not None:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass

    def get_preview_ui(self) -> ui.PreviewVideo | None:
        """Return the preview for the current or most recently finalized output."""
        return self._preview_ui

    def get_dimensions(self) -> tuple[int, int]:
        """Delegate to the wrapped stream's dimensions."""
        return self._stream.get_dimensions()

    def do_reset(self) -> None:
        """Drop any partial output, rewind the source, and start a new file."""
        self._discard_partial_output()
        self._stream.reset()
        self._path, self._preview_ui = self._output_factory()
        assert self._path is not None
        # Create the file immediately. NOTE(review): presumably this reserves
        # the name so the counter chosen by output_factory() advances on the
        # next reset — confirm against folder_paths.get_save_image_path.
        open(self._path, "ab").close()
        try:
            self._save_state = self._saver.save_start(
                self._path,
                format=self._format,
                codec=self._codec,
                metadata=self._metadata,
            )
        except Exception:
            # Don't leave the empty placeholder (or a stale preview) behind
            # when encoding cannot be started.
            self._discard_partial_output()
            raise

    def do_pull(self, max_frames: int) -> Input.Image:
        """Pull up to ``max_frames`` frames, encode them, and pass them through.

        A chunk shorter than ``max_frames`` is treated as end of stream. The
        file is finalized, the preview is emitted, and ``_path`` is cleared so
        the completed file survives later resets.
        """
        assert self._save_state is not None
        chunk = self._stream.pull(max_frames)
        self._saver.save_add(self._save_state[0], self._save_state[1], chunk)
        if chunk.shape[0] < max_frames:
            self._saver.save_finalize(*self._save_state)
            self._save_state = None
            self._emit_preview()
            self._path = None
        return chunk
|
|
|
|
|
|
def _build_saved_stream(
    hidden,
    stream: Input.ImageStream,
    audio: Optional[Input.Audio],
    fps: float,
    filename_prefix,
    format: str,
    codec,
    emit_preview: bool = True,
) -> SavedVideoStream:
    """Wrap ``stream`` in a ``SavedVideoStream`` that writes into the output folder.

    Workflow metadata (every ``extra_pnginfo`` entry plus the prompt) is
    embedded unless ``args.disable_metadata`` is set. Each reset of the
    returned stream resolves a new numbered file name from ``filename_prefix``.
    """
    frame_w, frame_h = stream.get_dimensions()

    embedded = None
    if not args.disable_metadata:
        collected = {}
        if hidden.extra_pnginfo is not None:
            collected.update(hidden.extra_pnginfo)
        if hidden.prompt is not None:
            collected["prompt"] = hidden.prompt
        # An empty dict means there is nothing to embed.
        embedded = collected or None

    node_id = hidden.unique_id
    if hidden.dynprompt is None or node_id is None:
        display_id = node_id
    else:
        display_id = hidden.dynprompt.get_display_node_id(node_id)

    def next_output() -> tuple[str, ui.PreviewVideo]:
        # Resolved lazily so that every (re)start gets its own counter.
        out_dir, base, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, folder_paths.get_output_directory(), frame_w, frame_h
        )
        name = f"{base}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
        target = os.path.join(out_dir, name)
        preview = ui.PreviewVideo([ui.SavedResult(name, subfolder, io.FolderType.output)])
        return target, preview

    # The saver starts with zero frames. Frames are appended chunk by chunk
    # as the wrapping stream is pulled.
    empty_frames = torch.zeros((0, frame_h, frame_w, 3))
    saver = InputImpl.VideoFromComponents(
        Types.VideoComponents(images=empty_frames, audio=audio, frame_rate=Fraction(fps))
    )
    return SavedVideoStream(
        stream,
        saver,
        next_output,
        format,
        codec,
        embedded,
        emit_preview_on_finalize=emit_preview,
        preview_node_id=node_id,
        preview_display_node_id=display_id,
    )
|
|
|
|
class SaveWEBM(io.ComfyNode):
    """Encode an IMAGE batch into a WebM file (VP9 or AV1) in the output folder."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveWEBM",
            search_aliases=["export webm"],
            category="image/video",
            is_experimental=True,
            inputs=[
                io.Image.Input("images"),
                io.String.Input("filename_prefix", default="ComfyUI"),
                io.Combo.Input("codec", options=["vp9", "av1"]),
                io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
                io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
            ],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
        """Write ``images`` to ``<output>/<prefix>_<counter>_.webm``.

        Args:
            images: image batch, indexed below as (frames, height, width,
                channels). Only the first three channels are encoded.
            codec: "vp9" or "av1".
            fps: frame rate, rounded to a 1/1000 rational.
            filename_prefix: output prefix resolved by ``folder_paths``.
            crf: constant rate factor forwarded to the encoder.

        Returns:
            A NodeOutput whose UI previews the saved file.
        """
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
        )

        file = f"{filename}_{counter:05}_.webm"
        codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
        # Resolve the encoder before creating the output file, so an unknown
        # codec does not leave an empty file behind.
        encoder = codec_map[codec]

        # The context manager closes the container (and its file handle) even
        # when encoding fails part-way through.
        with av.open(os.path.join(full_output_folder, file), mode="w") as container:
            # Honour --disable-metadata, matching SaveVideo.
            if not args.disable_metadata:
                if cls.hidden.prompt is not None:
                    container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
                if cls.hidden.extra_pnginfo is not None:
                    for key, value in cls.hidden.extra_pnginfo.items():
                        container.metadata[key] = json.dumps(value)

            stream = container.add_stream(encoder, rate=Fraction(round(fps * 1000), 1000))
            stream.width = images.shape[-2]
            stream.height = images.shape[-3]
            stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
            # bit_rate 0 leaves quality to crf (constant-quality mode).
            stream.bit_rate = 0
            stream.options = {'crf': str(crf)}
            if codec == "av1":
                stream.options["preset"] = "6"

            for image in images:
                rgb = torch.clamp(image[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy()
                frame = av.VideoFrame.from_ndarray(rgb, format="rgb24")
                for packet in stream.encode(frame):
                    container.mux(packet)
            # Flush packets still buffered inside the encoder.
            container.mux(stream.encode())

        return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
|
|
|
class SaveVideo(io.ComfyNode):
    """Save a VIDEO input into the output directory and preview it."""

    @classmethod
    def define_schema(cls):
        prefix_tip = "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."
        return io.Schema(
            node_id="SaveVideo",
            display_name="Save Video",
            search_aliases=["export video"],
            category="image/video",
            essentials_category="Basics",
            description="Saves the input images to your ComfyUI output directory.",
            inputs=[
                io.Video.Input("video", tooltip="The video to save."),
                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip=prefix_tip),
                io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
                io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
            ],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
        """Write ``video`` to a new numbered file and return a preview of it.

        Prompt/workflow metadata is embedded unless ``args.disable_metadata``
        is set.
        """
        w, h = video.get_dimensions()
        out_dir, base, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, folder_paths.get_output_directory(), w, h
        )

        meta = None
        if not args.disable_metadata:
            gathered = {}
            if cls.hidden.extra_pnginfo is not None:
                gathered.update(cls.hidden.extra_pnginfo)
            if cls.hidden.prompt is not None:
                gathered["prompt"] = cls.hidden.prompt
            meta = gathered if gathered else None

        name = f"{base}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
        video.save_to(
            os.path.join(out_dir, name),
            format=Types.VideoContainer(format),
            codec=codec,
            metadata=meta,
        )

        saved = ui.SavedResult(name, subfolder, io.FolderType.output)
        return io.NodeOutput(ui=ui.PreviewVideo([saved]))
|
|
|
|
|
|
class SavePassthroughVideoStream(io.ComfyNode):
    """Forward an image stream unchanged while saving its frames to a video."""

    @classmethod
    def define_schema(cls):
        prefix_tip = "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."
        stream_inputs = [
            io.ImageStream.Input("stream", tooltip="The image stream to save."),
            io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
            io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
            io.String.Input("filename_prefix", default="video/ComfyUI", tooltip=prefix_tip),
            io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
            io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
        ]
        return io.Schema(
            node_id="SavePassthroughVideoStream",
            display_name="Save+Passthrough Video Stream",
            search_aliases=["stream to video", "save image stream", "export video stream", "passthrough video stream"],
            category="image/video",
            essentials_category="Basics",
            description="Saves frames as they pass through the input image stream.",
            has_intermediate_output=True,
            inputs=stream_inputs,
            outputs=[io.ImageStream.Output(display_name="passthrough")],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
        )

    @classmethod
    def execute(
        cls,
        stream: Input.ImageStream,
        audio: Optional[Input.Audio],
        fps: float,
        filename_prefix,
        format: str,
        codec,
    ) -> io.NodeOutput:
        """Return a saving wrapper around ``stream`` as the passthrough output.

        The wrapper creates its file when the stream is reset. It encodes
        frames as the downstream consumer pulls them.
        """
        wrapped = _build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec)
        return io.NodeOutput(wrapped)
|
|
|
|
|
|
class SaveVideoStream(io.ComfyNode):
    """Terminal node that pulls an image stream to EOF and saves it as a video."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveVideoStream",
            search_aliases=["save image stream", "export video stream", "stream to video"],
            display_name="Save Video Stream",
            category="image/video",
            essentials_category="Basics",
            description="Saves an image stream by draining it directly to EOF.",
            is_output_node=True,
            inputs=[
                io.ImageStream.Input("stream", tooltip="The image stream to save."),
                io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
                io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
                io.Int.Input("chunk_size", default=8, min=1, max=4096),
                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
                io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
                io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt],
        )

    @classmethod
    def execute(
        cls,
        stream: Input.ImageStream,
        audio: Optional[Input.Audio],
        fps: float,
        chunk_size: int,
        filename_prefix,
        format: str,
        codec,
    ) -> io.NodeOutput:
        """Drain ``stream`` in ``chunk_size`` pulls and return the saved preview.

        The preview is returned through the node's UI output rather than
        pushed during finalization (``emit_preview=False``). If draining
        raises, the partially written file is discarded and the error is
        re-raised.
        """
        saved_stream = _build_saved_stream(
            cls.hidden,
            stream,
            audio,
            fps,
            filename_prefix,
            format,
            codec,
            emit_preview=False,
        )
        try:
            drain_image_stream(saved_stream, chunk_size, progress=FrameProgressTracker())
        except Exception:
            # A failed drain leaves an open encoder and a truncated file.
            # Discard both instead of leaving junk in the output directory.
            saved_stream._discard_partial_output()
            raise
        return io.NodeOutput(ui=saved_stream.get_preview_ui())
|
|
|
|
|
|
class CreateVideo(io.ComfyNode):
    """Combine an IMAGE batch and optional audio into a VIDEO object."""

    @classmethod
    def define_schema(cls):
        image_input = io.Image.Input("images", tooltip="The images to create a video from.")
        fps_input = io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0)
        audio_input = io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video.")
        return io.Schema(
            node_id="CreateVideo",
            display_name="Create Video",
            search_aliases=["images to video"],
            category="image/video",
            description="Create a video from images.",
            inputs=[image_input, fps_input, audio_input],
            outputs=[io.Video.Output()],
        )

    @classmethod
    def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
        """Wrap the frames, audio and ``Fraction(fps)`` frame rate as a video."""
        components = Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))
        video = InputImpl.VideoFromComponents(components)
        return io.NodeOutput(video)
|
|
|
|
class GetVideoComponents(io.ComfyNode):
    """Split a VIDEO into its frames, audio track, and frame rate."""

    @classmethod
    def define_schema(cls):
        produced = [
            io.Image.Output(display_name="images"),
            io.Audio.Output(display_name="audio"),
            io.Float.Output(display_name="fps"),
        ]
        return io.Schema(
            node_id="GetVideoComponents",
            display_name="Get Video Components",
            search_aliases=["extract frames", "split video", "video to images", "demux"],
            category="image/video",
            description="Extracts all components from a video: frames, audio, and framerate.",
            inputs=[io.Video.Input("video", tooltip="The video to extract components from.")],
            outputs=produced,
        )

    @classmethod
    def execute(cls, video: Input.Video) -> io.NodeOutput:
        """Return (images, audio, fps); the frame rate is converted to float."""
        parts = video.get_components()
        frames, sound, rate = parts.images, parts.audio, parts.frame_rate
        return io.NodeOutput(frames, sound, float(rate))
|
|
|
|
|
|
class LoadVideo(io.ComfyNode):
    """Load a video file from the ComfyUI input directory."""

    @classmethod
    def define_schema(cls):
        input_dir = folder_paths.get_input_directory()
        plain_files = [
            entry
            for entry in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, entry))
        ]
        video_files = sorted(folder_paths.filter_files_content_types(plain_files, ["video"]))
        return io.Schema(
            node_id="LoadVideo",
            display_name="Load Video",
            search_aliases=["import video", "open video", "video file"],
            category="image/video",
            essentials_category="Basics",
            inputs=[io.Combo.Input("file", options=video_files, upload=io.UploadType.video)],
            outputs=[io.Video.Output()],
        )

    @classmethod
    def execute(cls, file) -> io.NodeOutput:
        """Open the annotated input path as a ``VideoFromFile``."""
        return io.NodeOutput(InputImpl.VideoFromFile(folder_paths.get_annotated_filepath(file)))

    @classmethod
    def fingerprint_inputs(cls, file):
        """Fingerprint the input by its modification time.

        Reading the mtime avoids hashing a potentially large video file.
        """
        return os.path.getmtime(folder_paths.get_annotated_filepath(file))

    @classmethod
    def validate_inputs(cls, file):
        """Return True if the file exists, otherwise an error message."""
        if folder_paths.exists_annotated_filepath(file):
            return True
        return "Invalid video file: {}".format(file)
|
|
|
|
class VideoSlice(io.ComfyNode):
    """Trim a VIDEO to a window given by a start time and a duration in seconds."""

    @classmethod
    def define_schema(cls):
        video_input = io.Video.Input("video")
        # NOTE(review): start_time accepts negatives (min=-1e5); their meaning
        # is defined by Video.as_trimmed — confirm there.
        start_input = io.Float.Input(
            "start_time",
            default=0.0,
            max=1e5,
            min=-1e5,
            step=0.001,
            tooltip="Start time in seconds",
        )
        duration_input = io.Float.Input(
            "duration",
            default=0.0,
            min=0.0,
            step=0.001,
            tooltip="Duration in seconds, or 0 for unlimited duration",
        )
        strict_input = io.Boolean.Input(
            "strict_duration",
            default=False,
            tooltip="If True, when the specified duration is not possible, an error will be raised.",
        )
        return io.Schema(
            node_id="Video Slice",
            display_name="Video Slice",
            search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"],
            category="image/video",
            essentials_category="Video Tools",
            inputs=[video_input, start_input, duration_input, strict_input],
            outputs=[io.Video.Output()],
        )

    @classmethod
    def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
        """Return the trimmed video.

        Raises:
            ValueError: if ``video.as_trimmed`` returns None.
        """
        trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
        if trimmed is None:
            raise ValueError(
                f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
            )
        return io.NodeOutput(trimmed)
|
|
|
|
|
|
class VideoExtension(ComfyExtension):
    """Extension that exposes the video nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes, savers first and then the other video tools."""
        savers = [SaveWEBM, SaveVideo, SavePassthroughVideoStream, SaveVideoStream]
        tools = [CreateVideo, GetVideoComponents, LoadVideo, VideoSlice]
        return savers + tools
|
|
|
|
async def comfy_entrypoint() -> VideoExtension:
    """Create and return this module's ``VideoExtension`` instance."""
    extension = VideoExtension()
    return extension
|