mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Merge branch 'master' into fix-video-decode-alignment
This commit is contained in:
commit
e6077d86e9
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from comfy.ldm.lightricks.model import Timesteps
|
from comfy.ldm.lightricks.model import Timesteps
|
||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis):
|
|||||||
if x.shape[1] == 0:
|
if x.shape[1] == 0:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
return apply_rope1(x, freqs_cis)
|
||||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
|
||||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@ -27,10 +27,13 @@ class VideoInput(ABC):
|
|||||||
path: Union[str, IO[bytes]],
|
path: Union[str, IO[bytes]],
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Abstract method to save the video input to a file.
|
Abstract method to save the video input to a file.
|
||||||
|
|
||||||
|
bit_depth selects the encoded bit depth; None keeps the video's native depth.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -83,6 +86,14 @@ class VideoInput(ABC):
|
|||||||
components = self.get_components()
|
components = self.get_components()
|
||||||
return components.images.shape[2], components.images.shape[1]
|
return components.images.shape[2], components.images.shape[1]
|
||||||
|
|
||||||
|
def get_bit_depth(self) -> int:
    """
    Returns the bit depth of the video (e.g. 8 or 10).

    Default implementation returns 8; subclasses report their real depth.
    """
    return 8
|
||||||
|
|
||||||
def get_duration(self) -> float:
|
def get_duration(self) -> float:
|
||||||
"""
|
"""
|
||||||
Returns the duration of the video in seconds.
|
Returns the duration of the video in seconds.
|
||||||
|
|||||||
@ -52,6 +52,12 @@ def get_open_write_kwargs(
|
|||||||
return open_kwargs
|
return open_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def video_stream_bit_depth(stream) -> int:
    """Return the bit depth of a video stream, falling back to 8.

    The result is the largest ``bits`` value among ``stream.format.components``.
    8 is returned when ``stream`` is None, when ``stream.format`` is None, or
    when the format exposes no components.
    """
    if stream is None or stream.format is None or not stream.format.components:
        return 8
    return max(component.bits for component in stream.format.components)
|
||||||
|
|
||||||
|
|
||||||
class VideoFromFile(VideoInput):
|
class VideoFromFile(VideoInput):
|
||||||
"""
|
"""
|
||||||
Class representing video input from a file.
|
Class representing video input from a file.
|
||||||
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
|
|||||||
return stream.width, stream.height
|
return stream.width, stream.height
|
||||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
|
|
||||||
|
def get_bit_depth(self) -> int:
    """Return the bit depth of the first video stream of the underlying source.

    The source is opened read-only with ``av.open``. When the container has
    no video stream, ``video_stream_bit_depth`` is called with None.
    """
    if isinstance(self.__file, io.BytesIO):
        self.__file.seek(0)  # Reset the BytesIO object to the beginning
    with av.open(self.__file, mode="r") as container:
        video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
        return video_stream_bit_depth(video_stream)
|
||||||
|
|
||||||
def get_duration(self) -> float:
|
def get_duration(self) -> float:
|
||||||
"""
|
"""
|
||||||
Returns the duration of the video in seconds.
|
Returns the duration of the video in seconds.
|
||||||
@ -393,25 +406,32 @@ class VideoFromFile(VideoInput):
|
|||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: int | None = None,
|
||||||
):
|
):
|
||||||
if isinstance(self.__file, io.BytesIO):
|
if isinstance(self.__file, io.BytesIO):
|
||||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
with av.open(self.__file, mode='r') as container:
|
with av.open(self.__file, mode='r') as container:
|
||||||
container_format = container.format.name
|
container_format = container.format.name
|
||||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||||
|
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||||
|
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||||
reuse_streams = True
|
reuse_streams = True
|
||||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
|
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
|
||||||
|
reuse_streams = False
|
||||||
if self.__start_time or self.__duration:
|
if self.__start_time or self.__duration:
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
|
|
||||||
if not reuse_streams:
|
if not reuse_streams:
|
||||||
|
if bit_depth is None:
|
||||||
|
bit_depth = source_bit_depth
|
||||||
components = self.get_components_internal(container)
|
components = self.get_components_internal(container)
|
||||||
video = VideoFromComponents(components)
|
video = VideoFromComponents(components)
|
||||||
return video.save_to(
|
return video.save_to(
|
||||||
path, format=format, codec=codec, metadata=metadata
|
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
streams = container.streams
|
streams = container.streams
|
||||||
@ -467,8 +487,10 @@ class VideoFromComponents(VideoInput):
|
|||||||
Class representing video input from tensors.
|
Class representing video input from tensors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, components: VideoComponents):
|
def __init__(self, components: VideoComponents, bit_depth: int = 8):
|
||||||
self.__components = components
|
self.__components = components
|
||||||
|
# Tensor components have no inherent bit depth; this is the depth used when encoding.
|
||||||
|
self.__bit_depth = bit_depth
|
||||||
|
|
||||||
def get_components(self) -> VideoComponents:
|
def get_components(self) -> VideoComponents:
|
||||||
return VideoComponents(
|
return VideoComponents(
|
||||||
@ -477,18 +499,26 @@ class VideoFromComponents(VideoInput):
|
|||||||
frame_rate=self.__components.frame_rate,
|
frame_rate=self.__components.frame_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_bit_depth(self) -> int:
    """Return the bit depth stored on this instance."""
    return self.__bit_depth
|
||||||
|
|
||||||
def save_to(
|
def save_to(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: int | None = None,
|
||||||
):
|
):
|
||||||
"""Save the video to a file path or BytesIO buffer."""
|
"""Save the video to a file path or BytesIO buffer."""
|
||||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||||
raise ValueError("Only MP4 format is supported for now")
|
raise ValueError("Only MP4 format is supported for now")
|
||||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||||
raise ValueError("Only H264 codec is supported for now")
|
raise ValueError("Only H264 codec is supported for now")
|
||||||
|
# None means "use the depth this video was created with" (CreateVideo's choice).
|
||||||
|
if bit_depth is None:
|
||||||
|
bit_depth = self.__bit_depth
|
||||||
|
is_10bit = bit_depth >= 10
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||||
extra_kwargs["format"] = format.value
|
extra_kwargs["format"] = format.value
|
||||||
@ -504,10 +534,11 @@ class VideoFromComponents(VideoInput):
|
|||||||
|
|
||||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||||
# Create a video stream
|
# Create a video stream
|
||||||
|
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
|
||||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||||
video_stream.width = self.__components.images.shape[2]
|
video_stream.width = self.__components.images.shape[2]
|
||||||
video_stream.height = self.__components.images.shape[1]
|
video_stream.height = self.__components.images.shape[1]
|
||||||
video_stream.pix_fmt = 'yuv420p'
|
video_stream.pix_fmt = pix_fmt
|
||||||
|
|
||||||
# Create an audio stream
|
# Create an audio stream
|
||||||
audio_sample_rate = 1
|
audio_sample_rate = 1
|
||||||
@ -521,9 +552,14 @@ class VideoFromComponents(VideoInput):
|
|||||||
|
|
||||||
# Encode video
|
# Encode video
|
||||||
for i, frame in enumerate(self.__components.images):
|
for i, frame in enumerate(self.__components.images):
|
||||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
if is_10bit:
|
||||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
|
||||||
|
else:
|
||||||
|
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||||
|
frame = frame.reformat(format=pix_fmt)
|
||||||
packet = video_stream.encode(frame)
|
packet = video_stream.encode(frame)
|
||||||
output.mux(packet)
|
output.mux(packet)
|
||||||
|
|
||||||
|
|||||||
@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel):
|
|||||||
id: Optional[str] = Field(None, description='Task ID')
|
id: Optional[str] = Field(None, description='Task ID')
|
||||||
|
|
||||||
|
|
||||||
class RunwayTaskStatusEnum(str, Enum):
|
|
||||||
SUCCEEDED = 'SUCCEEDED'
|
|
||||||
RUNNING = 'RUNNING'
|
|
||||||
FAILED = 'FAILED'
|
|
||||||
PENDING = 'PENDING'
|
|
||||||
CANCELLED = 'CANCELLED'
|
|
||||||
THROTTLED = 'THROTTLED'
|
|
||||||
|
|
||||||
|
|
||||||
class RunwayTaskStatusResponse(BaseModel):
|
class RunwayTaskStatusResponse(BaseModel):
|
||||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||||
id: str = Field(..., description='Task ID')
|
id: str = Field(..., description='Task ID')
|
||||||
@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel):
|
|||||||
ge=0.0,
|
ge=0.0,
|
||||||
le=1.0,
|
le=1.0,
|
||||||
)
|
)
|
||||||
status: RunwayTaskStatusEnum
|
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
|
||||||
|
|
||||||
|
|
||||||
class Model4(str, Enum):
|
class Model4(str, Enum):
|
||||||
@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel):
|
|||||||
|
|
||||||
class RunwayTextToImageResponse(BaseModel):
|
class RunwayTextToImageResponse(BaseModel):
|
||||||
id: Optional[str] = Field(None, description='Task ID')
|
id: Optional[str] = Field(None, description='Task ID')
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2IO:
    """Custom socket types for chaining Aleph2 guidance images."""

    # Socket type name for keyframe chains.
    KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
    # Socket type name for prompt-image chains.
    PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
|
||||||
|
|
||||||
|
|
||||||
|
# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to
# choose the request model below. The values match the Aleph2 keyframe union field names.
KEYFRAME_MODE_SECONDS = "seconds"  # absolute time, in seconds, from the start of the input video
KEYFRAME_MODE_AT = "at"  # fraction [0.0, 1.0] of the input video duration

# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`.
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp"  # absolute time, in seconds, from the start of the output video
PROMPT_IMAGE_MODE_POSITION = "position"  # fraction [0.0, 1.0] of the output video duration
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeItem:
    """A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``)."""

    def __init__(self, image, mode: str, value: float):
        """Store the image together with its timing mode and timing value."""
        self.image = image
        self.mode = mode  # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT
        self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeChain:
    """An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes."""

    def __init__(self):
        """Start with an empty list of items."""
        self.items: list[RunwayAleph2KeyframeItem] = []

    def add(self, item: RunwayAleph2KeyframeItem) -> None:
        """Append ``item`` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "RunwayAleph2KeyframeChain":
        """Return a new chain holding a shallow copy of the item list.

        The items themselves are shared; only the list is copied, so adding to
        the clone does not change this chain.
        """
        c = RunwayAleph2KeyframeChain()
        c.items = list(self.items)
        return c
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2PromptImageItem:
    """A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``)."""

    def __init__(self, image, mode: str, value: float):
        """Store the image together with its position mode and position value."""
        self.image = image
        self.mode = mode  # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION
        self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2PromptImageChain:
    """An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes."""

    def __init__(self):
        """Start with an empty list of items."""
        self.items: list[RunwayAleph2PromptImageItem] = []

    def add(self, item: RunwayAleph2PromptImageItem) -> None:
        """Append ``item`` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "RunwayAleph2PromptImageChain":
        """Return a new chain holding a shallow copy of the item list.

        The items themselves are shared; only the list is copied, so adding to
        the clone does not change this chain.
        """
        c = RunwayAleph2PromptImageChain()
        c.items = list(self.items)
        return c
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeSeconds(BaseModel):
    """Keyframe entry positioned by an absolute, non-negative ``seconds`` value."""

    seconds: float = Field(
        ...,
        description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
        ge=0.0,
    )
    uri: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeAt(BaseModel):
    """Keyframe entry positioned by ``at``, a fraction constrained to [0.0, 1.0]."""

    at: float = Field(
        ...,
        description="Position as a fraction [0.0, 1.0] of the input video duration.",
        ge=0.0,
        le=1.0,
    )
    uri: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2TimestampPosition(BaseModel):
    """Prompt-image position given as a non-negative timestamp; ``type`` defaults to "timestamp"."""

    type: str = Field(default="timestamp")
    timestampSeconds: float = Field(
        ...,
        description="Absolute timestamp in seconds from the start of the output video.",
        ge=0.0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2RelativePosition(BaseModel):
    """Prompt-image position given as a fraction in [0.0, 1.0]; ``type`` defaults to "position"."""

    type: str = Field(default="position")
    positionPercentage: float = Field(
        ...,
        description="Position as a fraction [0.0, 1.0] of the total output video duration.",
        ge=0.0,
        le=1.0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2PromptImage(BaseModel):
    """One prompt image: a required ``uri`` plus either a timestamp or a relative position."""

    position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
    uri: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2ContentModeration(BaseModel):
    """Content moderation settings carried in an Aleph2 request."""

    # Declared as a plain str; the allowed values are only named in the description.
    publicFigureThreshold: str = Field(
        ...,
        description='When set to "low", the content moderation system is less strict about '
        'recognizable public figures. One of "auto" or "low".',
    )
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2Request(BaseModel):
    """Payload model for an Aleph2 request.

    ``promptText``, ``videoUri``, ``seed`` and ``contentModeration`` are
    required; ``keyframes`` and ``promptImage`` default to None.
    """

    model: str = Field(default="aleph2")
    promptText: str = Field(
        ...,
        description="A non-empty string describing what should appear in the output.",
        min_length=1,
        max_length=1000,
    )
    videoUri: str = Field(...)
    seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295)
    contentModeration: RunwayAleph2ContentModeration = Field(...)
    # NOTE(review): the "Up to 5" limits below are stated only in the descriptions;
    # no length constraint is declared on either list here.
    keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
        None,
        description="Timed guidance images placed at specific points in the input video. Up to 5.",
    )
    promptImage: list[RunwayAleph2PromptImage] | None = Field(
        None,
        description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2Response(BaseModel):
    """Response model carrying the optional task ``id``."""

    id: str | None = Field(None, description="Task ID")
|
||||||
|
|||||||
@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import (
|
|||||||
Model4,
|
Model4,
|
||||||
ReferenceImage,
|
ReferenceImage,
|
||||||
RunwayTextToImageAspectRatioEnum,
|
RunwayTextToImageAspectRatioEnum,
|
||||||
|
RunwayAleph2IO,
|
||||||
|
RunwayAleph2KeyframeChain,
|
||||||
|
RunwayAleph2KeyframeItem,
|
||||||
|
RunwayAleph2PromptImageChain,
|
||||||
|
RunwayAleph2PromptImageItem,
|
||||||
|
RunwayAleph2Request,
|
||||||
|
RunwayAleph2Response,
|
||||||
|
RunwayAleph2KeyframeSeconds,
|
||||||
|
RunwayAleph2KeyframeAt,
|
||||||
|
RunwayAleph2PromptImage,
|
||||||
|
RunwayAleph2TimestampPosition,
|
||||||
|
RunwayAleph2RelativePosition,
|
||||||
|
RunwayAleph2ContentModeration,
|
||||||
|
KEYFRAME_MODE_SECONDS,
|
||||||
|
KEYFRAME_MODE_AT,
|
||||||
|
PROMPT_IMAGE_MODE_TIMESTAMP,
|
||||||
|
PROMPT_IMAGE_MODE_POSITION,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
image_tensor_pair_to_batch,
|
image_tensor_pair_to_batch,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
validate_image_aspect_ratio,
|
validate_image_aspect_ratio,
|
||||||
|
validate_video_duration,
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
|
upload_image_to_comfyapi,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
download_url_to_video_output,
|
download_url_to_video_output,
|
||||||
download_url_to_image_tensor,
|
download_url_to_image_tensor,
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -45,6 +65,7 @@ from comfy_api_nodes.util import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||||
|
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
|
||||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||||
|
|
||||||
@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256
|
|||||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||||
|
|
||||||
|
|
||||||
class RunwayApiError(Exception):
|
|
||||||
"""Base exception for Runway API errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||||
|
|
||||||
@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_progress_from_task_status(
|
|
||||||
response: TaskStatusResponse,
|
|
||||||
) -> float | None:
|
|
||||||
if hasattr(response, "progress") and response.progress is not None:
|
|
||||||
return response.progress * 100
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||||
"""Returns the image URL from the task status response if it exists."""
|
"""Returns the image URL from the task status response if it exists."""
|
||||||
if hasattr(response, "output") and len(response.output) > 0:
|
if hasattr(response, "output") and len(response.output) > 0:
|
||||||
@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
|||||||
async def get_response(
|
async def get_response(
|
||||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||||
) -> TaskStatusResponse:
|
) -> TaskStatusResponse:
|
||||||
"""Poll the task status until it is finished then get the response."""
|
|
||||||
return await poll_op(
|
return await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||||
response_model=TaskStatusResponse,
|
response_model=TaskStatusResponse,
|
||||||
status_extractor=lambda r: r.status.value,
|
status_extractor=lambda r: r.status,
|
||||||
estimated_duration=estimated_duration,
|
estimated_duration=estimated_duration,
|
||||||
progress_extractor=extract_progress_from_task_status,
|
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -127,7 +133,7 @@ async def generate_video(
|
|||||||
|
|
||||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||||
if not final_response.output:
|
if not final_response.output:
|
||||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||||
|
|
||||||
video_url = get_video_url_from_task_status(final_response)
|
video_url = get_video_url_from_task_status(final_response)
|
||||||
return await download_url_to_video_output(video_url)
|
return await download_url_to_video_output(video_url)
|
||||||
@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
|||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
)
|
)
|
||||||
if len(download_urls) != 2:
|
if len(download_urls) != 2:
|
||||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
raise ValueError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await generate_video(
|
await generate_video(
|
||||||
@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
|||||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||||
)
|
)
|
||||||
if not final_response.output:
|
if not final_response.output:
|
||||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
raise ValueError("Runway task succeeded but no image data found in response.")
|
||||||
|
|
||||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||||
|
|
||||||
|
|
||||||
|
_TIMING_ABSOLUTE = "Absolute time (seconds)"
|
||||||
|
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeNode(IO.ComfyNode):
    """Node that appends one input-video keyframe to an Aleph2 keyframe chain."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs (image, timing selector, optional upstream chain) and its chain output."""
        return IO.Schema(
            node_id="RunwayAleph2KeyframeNode",
            display_name="Runway Aleph2 Keyframe",
            category="partner/video/Runway",
            description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
            "steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
            "the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
            "'keyframes' input below.",
            inputs=[
                IO.Image.Input(
                    "image",
                    tooltip="The guidance image to apply at the chosen moment of the input video.",
                ),
                IO.DynamicCombo.Input(
                    "timing",
                    options=[
                        IO.DynamicCombo.Option(
                            _TIMING_ABSOLUTE,
                            [
                                IO.Float.Input(
                                    "seconds",
                                    default=0.0,
                                    min=0.0,
                                    max=30.0,
                                    step=0.1,
                                    display_mode=IO.NumberDisplay.number,
                                    tooltip="Time in seconds from start of the input video where this image applies.",
                                ),
                            ],
                        ),
                        IO.DynamicCombo.Option(
                            _TIMING_FRACTION,
                            [
                                IO.Float.Input(
                                    "fraction",
                                    default=0.0,
                                    min=0.0,
                                    max=1.0,
                                    step=0.01,
                                    display_mode=IO.NumberDisplay.number,
                                    tooltip="Where in the input video this image applies, "
                                    "as a fraction of its duration (0.0 = start, 1.0 = end).",
                                ),
                            ],
                        ),
                    ],
                    tooltip="How to place this image on the input video's timeline.",
                ),
                IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
                    "keyframes",
                    optional=True,
                    tooltip="Optional earlier keyframes to chain with this one.",
                ),
            ],
            outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        timing: dict,
        keyframes: RunwayAleph2KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain made of the incoming keyframes plus one new item for ``image``.

        ``timing["timing"]`` holds the selected option label; the value is read
        from ``timing["seconds"]`` for the absolute option and from
        ``timing["fraction"]`` otherwise.
        """
        # Clone first so the chain received from upstream is not modified.
        chain = keyframes.clone() if keyframes is not None else RunwayAleph2KeyframeChain()
        if timing["timing"] == _TIMING_ABSOLUTE:
            mode, value = KEYFRAME_MODE_SECONDS, float(timing["seconds"])
        else:
            mode, value = KEYFRAME_MODE_AT, float(timing["fraction"])
        chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2PromptImageNode(IO.ComfyNode):
    """Node that appends one output-video prompt image to an Aleph2 prompt-image chain."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs (image, position selector, optional upstream chain) and its chain output."""
        return IO.Schema(
            node_id="RunwayAleph2PromptImageNode",
            display_name="Runway Aleph2 Prompt Image",
            category="partner/video/Runway",
            description="Anchor a guidance image to a moment of the output (result) video, to guide what "
            "the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
            "Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
            "'prompt_images' input below.",
            inputs=[
                IO.Image.Input(
                    "image",
                    tooltip="The guidance image to place at the chosen moment of the output video.",
                ),
                IO.DynamicCombo.Input(
                    "position",
                    options=[
                        IO.DynamicCombo.Option(
                            _TIMING_ABSOLUTE,
                            [
                                IO.Float.Input(
                                    "seconds",
                                    default=0.0,
                                    min=0.0,
                                    max=30.0,
                                    step=0.1,
                                    display_mode=IO.NumberDisplay.number,
                                    tooltip="Time in seconds from start of the output video where this image applies.",
                                ),
                            ],
                        ),
                        IO.DynamicCombo.Option(
                            _TIMING_FRACTION,
                            [
                                IO.Float.Input(
                                    "fraction",
                                    default=0.0,
                                    min=0.0,
                                    max=1.0,
                                    step=0.01,
                                    display_mode=IO.NumberDisplay.number,
                                    tooltip="Where in the output video this image applies, "
                                    "as a fraction of its duration (0.0 = start, 1.0 = end).",
                                ),
                            ],
                        ),
                    ],
                    tooltip="How to place this image on the output video's timeline.",
                ),
                IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
                    "prompt_images",
                    optional=True,
                    tooltip="Optional earlier prompt images to chain with this one.",
                ),
            ],
            outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        position: dict,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain made of the incoming prompt images plus one new item for ``image``.

        ``position["position"]`` holds the selected option label; the value is
        read from ``position["seconds"]`` for the absolute option and from
        ``position["fraction"]`` otherwise.
        """
        # Clone first so the chain received from upstream is not modified.
        chain = prompt_images.clone() if prompt_images is not None else RunwayAleph2PromptImageChain()
        if position["position"] == _TIMING_ABSOLUTE:
            mode, value = PROMPT_IMAGE_MODE_TIMESTAMP, float(position["seconds"])
        else:
            mode, value = PROMPT_IMAGE_MODE_POSITION, float(position["fraction"])
        chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
    """API node that edits an input video with Runway's Aleph2 model.

    The edit is driven by a text prompt and may optionally be steered by
    keyframes or by prompt images (never both in the same request).
    """

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs, outputs, hidden fields and price badge."""
        return IO.Schema(
            node_id="RunwayAleph2VideoToVideoNode",
            display_name="Runway Aleph2 Video to Video",
            category="partner/video/Runway",
            description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
            "your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
            "the original motion and timing; the output resolution matches the input video, which must be "
            "2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
            "the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Describes what should appear in the output (1-1000 characters).",
                ),
                IO.Video.Input(
                    "video",
                    tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=4294967295,
                    step=1,
                    control_after_generate=True,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Random seed for generation",
                ),
                IO.Combo.Input(
                    "public_figure_threshold",
                    options=["auto", "low"],
                    default="low",
                    tooltip="Content moderation for recognizable public figures.",
                ),
                IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
                    "keyframes",
                    optional=True,
                    tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
                IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
                    "prompt_images",
                    optional=True,
                    tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        video: Input.Video,
        seed: int,
        public_figure_threshold: str = "low",
        keyframes: RunwayAleph2KeyframeChain | None = None,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Validate the inputs, upload the media, run the Aleph2 task and return its video.

        All local validation (prompt length, duration, frame rate, chain
        exclusivity, chain size, timestamp bounds) is performed before the
        first upload, so an invalid request fails without transferring any
        media.

        Args:
            prompt: Text describing the desired output (1-1000 characters).
            video: Input video to edit.
            seed: Seed forwarded to the request.
            public_figure_threshold: Content-moderation setting forwarded to the request.
            keyframes: Optional chain of keyframes (at most 5 items).
            prompt_images: Optional chain of prompt images (at most 5 items).

        Returns:
            A node output wrapping the downloaded result video.

        Raises:
            ValueError: If the frame rate exceeds 30 fps, if both chains carry
                items, if a chain has more than 5 items, if an absolute
                timestamp lies beyond the input video's duration, or if the
                finished task has no output.
        """
        validate_string(prompt, min_length=1, max_length=1000)
        validate_video_duration(
            video,
            min_duration=2.0,
            max_duration=30.0,
        )
        # Best effort: when the frame rate cannot be read, skip the local fps check.
        try:
            fps = float(video.get_frame_rate())
        except Exception:
            fps = None
        # Small tolerance so nominal 30 fps sources are not rejected on rounding.
        if fps is not None and fps > 30.0 + 0.01:
            raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")

        if (keyframes and keyframes.items) and (prompt_images and prompt_images.items):
            raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")

        keyframe_items = keyframes.items if keyframes is not None else []
        prompt_image_items = prompt_images.items if prompt_images is not None else []
        if len(keyframe_items) > 5:
            raise ValueError("Aleph2 supports at most 5 keyframes.")
        if len(prompt_image_items) > 5:
            raise ValueError("Aleph2 supports at most 5 prompt images.")

        # Best effort: when the duration cannot be read, timestamp bounds are not checked locally.
        video_duration: float | None = None
        try:
            video_duration = video.get_duration()
        except Exception:
            video_duration = None

        def _check_seconds(value: float, label: str) -> None:
            """Reject an absolute timestamp that lies past the end of the input video."""
            if video_duration is not None and value > video_duration + 0.0001:
                raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")

        # Check every absolute timestamp up front, before any network transfer.
        for item in keyframe_items:
            if item.mode == KEYFRAME_MODE_SECONDS:
                _check_seconds(item.value, "Keyframe timestamp")
        for item in prompt_image_items:
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                _check_seconds(item.value, "Prompt image timestamp")

        video_url = await upload_video_to_comfyapi(cls, video)

        keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
        for item in keyframe_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            if item.mode == KEYFRAME_MODE_SECONDS:
                keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
            else:
                keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))

        prompt_image_models: list[RunwayAleph2PromptImage] = []
        for item in prompt_image_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
            else:
                position = RunwayAleph2RelativePosition(positionPercentage=item.value)
            prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))

        initial_response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
            response_model=RunwayAleph2Response,
            data=RunwayAleph2Request(
                promptText=prompt,
                videoUri=video_url,
                seed=seed,
                contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
                # Empty lists are sent as None.
                keyframes=keyframe_models or None,
                promptImage=prompt_image_models or None,
            ),
        )

        final_response = await get_response(cls, initial_response.id)
        if not final_response.output:
            raise ValueError("Runway task succeeded but no video data found in response.")

        return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
|
||||||
|
|
||||||
|
|
||||||
class RunwayExtension(ComfyExtension):
|
class RunwayExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension):
|
|||||||
RunwayImageToVideoNodeGen3a,
|
RunwayImageToVideoNodeGen3a,
|
||||||
RunwayImageToVideoNodeGen4,
|
RunwayImageToVideoNodeGen4,
|
||||||
RunwayTextToImageNode,
|
RunwayTextToImageNode,
|
||||||
|
RunwayAleph2VideoToVideoNode,
|
||||||
|
RunwayAleph2KeyframeNode,
|
||||||
|
RunwayAleph2PromptImageNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -134,6 +134,17 @@ class CreateVideo(io.ComfyNode):
|
|||||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||||
|
io.Int.Input(
|
||||||
|
"bit_depth",
|
||||||
|
min=8,
|
||||||
|
max=10,
|
||||||
|
default=8,
|
||||||
|
step=2,
|
||||||
|
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
|
||||||
|
" banding, but some players and downstream nodes may not support it.",
|
||||||
|
optional=True,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Video.Output(),
|
io.Video.Output(),
|
||||||
@ -141,9 +152,14 @@ class CreateVideo(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
def execute(
|
||||||
|
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
|
||||||
|
) -> io.NodeOutput:
|
||||||
return io.NodeOutput(
|
return io.NodeOutput(
|
||||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
InputImpl.VideoFromComponents(
|
||||||
|
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
|
||||||
|
bit_depth=bit_depth,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
class GetVideoComponents(io.ComfyNode):
|
class GetVideoComponents(io.ComfyNode):
|
||||||
@ -154,7 +170,7 @@ class GetVideoComponents(io.ComfyNode):
|
|||||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||||
display_name="Get Video Components",
|
display_name="Get Video Components",
|
||||||
category="video",
|
category="video",
|
||||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||||
],
|
],
|
||||||
@ -162,13 +178,14 @@ class GetVideoComponents(io.ComfyNode):
|
|||||||
io.Image.Output(display_name="images"),
|
io.Image.Output(display_name="images"),
|
||||||
io.Audio.Output(display_name="audio"),
|
io.Audio.Output(display_name="audio"),
|
||||||
io.Float.Output(display_name="fps"),
|
io.Float.Output(display_name="fps"),
|
||||||
|
io.Int.Output(display_name="bit_depth"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||||
components = video.get_components()
|
components = video.get_components()
|
||||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
|
||||||
|
|
||||||
|
|
||||||
class LoadVideo(io.ComfyNode):
|
class LoadVideo(io.ComfyNode):
|
||||||
|
|||||||
@ -27,6 +27,7 @@ import logging
|
|||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from comfy.deploy_environment import get_deploy_environment
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy_api import feature_flags
|
from comfy_api import feature_flags
|
||||||
@ -690,6 +691,7 @@ class PromptServer():
|
|||||||
"python_version": sys.version,
|
"python_version": sys.version,
|
||||||
"pytorch_version": comfy.model_management.torch_version,
|
"pytorch_version": comfy.model_management.torch_version,
|
||||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||||
|
"deploy_environment": get_deploy_environment(),
|
||||||
"argv": sys.argv
|
"argv": sys.argv
|
||||||
},
|
},
|
||||||
"devices": device_entries
|
"devices": device_entries
|
||||||
|
|||||||
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
from fractions import Fraction
|
||||||
|
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||||
|
from comfy_api.latest._util.video_types import VideoComponents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def gradient_components():
    """Return a 3-frame, 64x64, 3-channel clip holding a narrow 0.25..0.30 ramp.

    The ramp varies along the third axis only and is deliberately shallow,
    so that it needs more than 8 bits per channel to stay smooth.
    """
    side = 64
    n_frames = 3
    ramp_row = torch.linspace(0.25, 0.30, side)
    # Broadcast the 1-D ramp over frames, the second axis and channels.
    clip = ramp_row.view(1, 1, side, 1).expand(n_frames, side, side, 3)
    return VideoComponents(images=clip.contiguous(), frame_rate=Fraction(30))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
    """Save the gradient clip with the default bit depth and return the mp4 path."""
    out_dir = tmp_path_factory.mktemp("video")
    target = str(out_dir / "src8.mp4")
    clip = VideoFromComponents(gradient_components)
    clip.save_to(target)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
    """Save the gradient clip with bit_depth=10 and return the mp4 path."""
    out_dir = tmp_path_factory.mktemp("video")
    target = str(out_dir / "src10.mp4")
    clip = VideoFromComponents(gradient_components, bit_depth=10)
    clip.save_to(target)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def probe(path):
    """Return (codec name, pixel format name, bit depth) for the first video stream.

    The bit depth is the largest per-component bit count of the stream's
    pixel format.
    """
    with av.open(path) as media:
        video = media.streams.video[0]
        codec_name = video.codec.name
        pix_fmt_name = video.format.name
        depth = max(component.bits for component in video.format.components)
        return (codec_name, pix_fmt_name, depth)
|
||||||
|
|
||||||
|
|
||||||
|
def decoded_levels(path):
    """Count the distinct values in one plane of the first decoded frame.

    Fewer distinct levels on the gradient clip means more banding.
    """
    with av.open(path) as media:
        video = media.streams.video[0]
        first_frame = next(media.decode(video))
        # Convert to float planes and keep index 0 of the last axis.
        pixels = first_frame.to_ndarray(format="gbrpf32le")
        plane = pixels[..., 0]
        return len(np.unique(plane))
|
||||||
|
|
||||||
|
|
||||||
|
def video_packet_bytes(path):
    """Return the payload bytes of every non-empty packet in the first video stream.

    Two files yield identical lists only when one is a true remux of the other.
    """
    payloads = []
    with av.open(path) as media:
        video = media.streams.video[0]
        for packet in media.demux(video):
            # Skip zero-size packets.
            if packet.size:
                payloads.append(bytes(packet))
        return payloads
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_video_bit_depth(src8, src10):
    """Check that bit_depth selects the encoded depth and that 10-bit reduces banding.

    The default clip must be 8-bit h264, the bit_depth=10 clip must be
    10-bit h264, and the 10-bit clip must decode to more than twice as many
    distinct levels.
    """
    expected_8bit = ("h264", "yuv420p", 8)
    expected_10bit = ("h264", "yuv420p10le", 10)

    assert probe(src8) == expected_8bit
    assert probe(src10) == expected_10bit
    assert decoded_levels(src10) > 2 * decoded_levels(src8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
    """Check that saving without bit_depth copies the source stream unchanged.

    For both the 8-bit and the 10-bit source, the saved file must report the
    same codec, pixel format and depth, and carry identical packet payloads.
    """
    sources = {"p8": src8, "p10": src10}
    for label, source in sources.items():
        copy_path = str(tmp_path / f"{label}.mp4")
        VideoFromFile(source).save_to(copy_path)
        assert probe(copy_path) == probe(source)
        assert video_packet_bytes(copy_path) == video_packet_bytes(source)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
    """Check that an explicit bit_depth differing from the source is honoured.

    Covers both directions: the 10-bit source saved as 8-bit, then the 8-bit
    source saved as 10-bit.
    """
    cases = (
        ("down8.mp4", src10, 8, ("h264", "yuv420p", 8)),
        ("up10.mp4", src8, 10, ("h264", "yuv420p10le", 10)),
    )
    for filename, source, depth, expected in cases:
        target = str(tmp_path / filename)
        VideoFromFile(source).save_to(target, bit_depth=depth)
        assert probe(target) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim_keeps_source_depth(src10, tmp_path):
    """Check that a trimmed 10-bit source is still saved as 10-bit h264."""
    target = str(tmp_path / "trim.mp4")
    source = VideoFromFile(src10)
    # Keep one frame's worth of video (1/30 s) from the start.
    trimmed = source.as_trimmed(start_time=0, duration=1 / 30, strict_duration=False)
    trimmed.save_to(target)
    assert probe(target) == ("h264", "yuv420p10le", 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bit_depth(gradient_components, src8, src10):
    """Check get_bit_depth for file-backed and component-backed videos.

    File-backed videos must report the depth they were encoded with;
    component-backed videos must report the requested depth, or 8 when none
    was given.
    """
    for path, depth in ((src8, 8), (src10, 10)):
        assert VideoFromFile(path).get_bit_depth() == depth

    explicit = VideoFromComponents(gradient_components, bit_depth=10)
    assert explicit.get_bit_depth() == 10

    default = VideoFromComponents(gradient_components)
    assert default.get_bit_depth() == 8
|
||||||
Loading…
Reference in New Issue
Block a user