mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 09:19:46 +08:00
Merge branch 'master' into expose-deploy-environment
This commit is contained in:
commit
55642474df
@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
|||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------|-------------|
|
|------|-------------|
|
||||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) |
|
||||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav
|
|||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||||
|
cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||||
@ -133,7 +134,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
|
|||||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||||
manager_group = parser.add_mutually_exclusive_group()
|
manager_group = parser.add_mutually_exclusive_group()
|
||||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.")
|
||||||
|
|
||||||
|
|
||||||
vram_group = parser.add_mutually_exclusive_group()
|
vram_group = parser.add_mutually_exclusive_group()
|
||||||
@ -249,6 +250,9 @@ else:
|
|||||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||||
|
|
||||||
|
if args.high_ram:
|
||||||
|
args.cache_classic = True
|
||||||
|
|
||||||
if args.windows_standalone_build:
|
if args.windows_standalone_build:
|
||||||
args.auto_launch = True
|
args.auto_launch = True
|
||||||
|
|
||||||
@ -258,6 +262,10 @@ if args.disable_auto_launch:
|
|||||||
if args.force_fp16:
|
if args.force_fp16:
|
||||||
args.fp16_unet = True
|
args.fp16_unet = True
|
||||||
|
|
||||||
|
# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'.
|
||||||
|
if args.enable_manager_legacy_ui:
|
||||||
|
args.enable_manager = True
|
||||||
|
|
||||||
|
|
||||||
# '--fast' is not provided, use an empty set
|
# '--fast' is not provided, use an empty set
|
||||||
if args.fast is None:
|
if args.fast is None:
|
||||||
|
|||||||
@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module):
|
|||||||
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||||
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x, dtype):
    """Embed a scalar conditioning value and project it through the MLP.

    Args:
        x: Tensor of scalar values to embed.
        dtype: dtype the sinusoidal embedding is cast to before it enters
            ``mlp_in``. The caller passes the activation dtype explicitly
            instead of this method reading it from ``mlp_in.weight``.

    Returns:
        The output of ``mlp_out`` applied to ``silu(mlp_in(embedding))``.
    """
    # The embedding itself is always computed in float32.
    x = x.to(torch.float32)
    # Linear rescale: range_min maps to 0 and range_max maps to 1e4.
    scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
    emb = _sinusoidal_embedding(scaled, self.dim)
    # Cast to the caller-requested dtype only after the float32 math is done.
    emb = emb.to(dtype)
    emb = F.silu(self.mlp_in(emb))
    return self.mlp_out(emb)
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module):
|
|||||||
x = x * output_image_mask
|
x = x * output_image_mask
|
||||||
h = self.input_proj(x) * output_image_mask
|
h = self.input_proj(x) * output_image_mask
|
||||||
|
|
||||||
t_cond = self.t_embedding(t)
|
t_cond = self.t_embedding(t, dtype=x.dtype)
|
||||||
if t.dim() == 1:
|
if t.dim() == 1:
|
||||||
t_cond = t_cond.unsqueeze(1)
|
t_cond = t_cond.unsqueeze(1)
|
||||||
adaln_input = F.silu(self.adaln_proj(t_cond))
|
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from comfy.ldm.lightricks.model import Timesteps
|
from comfy.ldm.lightricks.model import Timesteps
|
||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis):
|
|||||||
if x.shape[1] == 0:
|
if x.shape[1] == 0:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
return apply_rope1(x, freqs_cis)
|
||||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
|
||||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@ -643,6 +643,8 @@ def free_pins(size, evict_active=False):
|
|||||||
return freed_total
|
return freed_total
|
||||||
|
|
||||||
def ensure_pin_budget(size, evict_active=False):
|
def ensure_pin_budget(size, evict_active=False):
|
||||||
|
if args.high_ram:
|
||||||
|
return True
|
||||||
if args.fast_disk:
|
if args.fast_disk:
|
||||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||||
else:
|
else:
|
||||||
@ -1496,6 +1498,8 @@ if not args.disable_pinned_memory:
|
|||||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||||
|
|
||||||
def pinned_hostbuf_size(size):
    """Return the pinned host buffer size to use for a request of ``size``.

    The result is twice the requested size, truncated to an int and never
    negative. Unless ``--high-ram`` is set, the requested size is first
    capped at ``MAX_PINNED_MEMORY``.
    """
    if args.high_ram:
        # High-RAM mode skips the MAX_PINNED_MEMORY cap entirely.
        return max(0, int(size * 2))
    return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||||
|
|
||||||
def discard_cuda_async_error():
|
def discard_cuda_async_error():
|
||||||
|
|||||||
@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
|||||||
if pin is not None:
|
if pin is not None:
|
||||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||||
return
|
return
|
||||||
if signature is None:
|
if signature is None or args.high_ram:
|
||||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||||
|
|||||||
@ -27,10 +27,13 @@ class VideoInput(ABC):
|
|||||||
path: Union[str, IO[bytes]],
|
path: Union[str, IO[bytes]],
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Abstract method to save the video input to a file.
|
Abstract method to save the video input to a file.
|
||||||
|
|
||||||
|
bit_depth selects the encoded bit depth; None keeps the video's native depth.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -83,6 +86,14 @@ class VideoInput(ABC):
|
|||||||
components = self.get_components()
|
components = self.get_components()
|
||||||
return components.images.shape[2], components.images.shape[1]
|
return components.images.shape[2], components.images.shape[1]
|
||||||
|
|
||||||
|
def get_bit_depth(self) -> int:
    """
    Returns the bit depth of the video (e.g. 8 or 10).

    Default implementation returns 8; subclasses report their real depth.
    """
    return 8
|
||||||
|
|
||||||
def get_duration(self) -> float:
|
def get_duration(self) -> float:
|
||||||
"""
|
"""
|
||||||
Returns the duration of the video in seconds.
|
Returns the duration of the video in seconds.
|
||||||
|
|||||||
@ -52,6 +52,12 @@ def get_open_write_kwargs(
|
|||||||
return open_kwargs
|
return open_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def video_stream_bit_depth(stream) -> int:
    """Return the bit depth of a video stream, defaulting to 8.

    Args:
        stream: A video stream object exposing ``format.components`` (each
            component having a ``bits`` attribute), or ``None``.

    Returns:
        The largest ``bits`` value among the stream's format components, or
        8 when the stream, its format, or its component list is missing/empty.
    """
    if stream is None or stream.format is None or not stream.format.components:
        return 8
    # Components may differ in width; report the widest one.
    return max(component.bits for component in stream.format.components)
|
||||||
|
|
||||||
|
|
||||||
class VideoFromFile(VideoInput):
|
class VideoFromFile(VideoInput):
|
||||||
"""
|
"""
|
||||||
Class representing video input from a file.
|
Class representing video input from a file.
|
||||||
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
|
|||||||
return stream.width, stream.height
|
return stream.width, stream.height
|
||||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
|
|
||||||
|
def get_bit_depth(self) -> int:
    """Return the bit depth of the first video stream in the source file.

    Opens the source with ``av.open`` and delegates to
    ``video_stream_bit_depth``, which yields 8 when there is no video stream.
    """
    if isinstance(self.__file, io.BytesIO):
        self.__file.seek(0)  # Reset the BytesIO object to the beginning
    with av.open(self.__file, mode="r") as container:
        video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
        return video_stream_bit_depth(video_stream)
|
||||||
|
|
||||||
def get_duration(self) -> float:
|
def get_duration(self) -> float:
|
||||||
"""
|
"""
|
||||||
Returns the duration of the video in seconds.
|
Returns the duration of the video in seconds.
|
||||||
@ -377,25 +390,32 @@ class VideoFromFile(VideoInput):
|
|||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: int | None = None,
|
||||||
):
|
):
|
||||||
if isinstance(self.__file, io.BytesIO):
|
if isinstance(self.__file, io.BytesIO):
|
||||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
with av.open(self.__file, mode='r') as container:
|
with av.open(self.__file, mode='r') as container:
|
||||||
container_format = container.format.name
|
container_format = container.format.name
|
||||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||||
|
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||||
|
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||||
reuse_streams = True
|
reuse_streams = True
|
||||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
|
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
|
||||||
|
reuse_streams = False
|
||||||
if self.__start_time or self.__duration:
|
if self.__start_time or self.__duration:
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
|
|
||||||
if not reuse_streams:
|
if not reuse_streams:
|
||||||
|
if bit_depth is None:
|
||||||
|
bit_depth = source_bit_depth
|
||||||
components = self.get_components_internal(container)
|
components = self.get_components_internal(container)
|
||||||
video = VideoFromComponents(components)
|
video = VideoFromComponents(components)
|
||||||
return video.save_to(
|
return video.save_to(
|
||||||
path, format=format, codec=codec, metadata=metadata
|
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
streams = container.streams
|
streams = container.streams
|
||||||
@ -451,8 +471,10 @@ class VideoFromComponents(VideoInput):
|
|||||||
Class representing video input from tensors.
|
Class representing video input from tensors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, components: VideoComponents, bit_depth: int = 8):
    """Wrap in-memory video components.

    Args:
        components: The video components this instance represents.
        bit_depth: Depth used when encoding; defaults to 8.
    """
    self.__components = components
    # Tensor components have no inherent bit depth; this is the depth used when encoding.
    self.__bit_depth = bit_depth
|
||||||
|
|
||||||
def get_components(self) -> VideoComponents:
|
def get_components(self) -> VideoComponents:
|
||||||
return VideoComponents(
|
return VideoComponents(
|
||||||
@ -461,18 +483,26 @@ class VideoFromComponents(VideoInput):
|
|||||||
frame_rate=self.__components.frame_rate,
|
frame_rate=self.__components.frame_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_bit_depth(self) -> int:
    """Return the bit depth stored on this instance at construction time."""
    return self.__bit_depth
|
||||||
|
|
||||||
def save_to(
|
def save_to(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: int | None = None,
|
||||||
):
|
):
|
||||||
"""Save the video to a file path or BytesIO buffer."""
|
"""Save the video to a file path or BytesIO buffer."""
|
||||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||||
raise ValueError("Only MP4 format is supported for now")
|
raise ValueError("Only MP4 format is supported for now")
|
||||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||||
raise ValueError("Only H264 codec is supported for now")
|
raise ValueError("Only H264 codec is supported for now")
|
||||||
|
# None means "use the depth this video was created with" (CreateVideo's choice).
|
||||||
|
if bit_depth is None:
|
||||||
|
bit_depth = self.__bit_depth
|
||||||
|
is_10bit = bit_depth >= 10
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||||
extra_kwargs["format"] = format.value
|
extra_kwargs["format"] = format.value
|
||||||
@ -488,10 +518,11 @@ class VideoFromComponents(VideoInput):
|
|||||||
|
|
||||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||||
# Create a video stream
|
# Create a video stream
|
||||||
|
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
|
||||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||||
video_stream.width = self.__components.images.shape[2]
|
video_stream.width = self.__components.images.shape[2]
|
||||||
video_stream.height = self.__components.images.shape[1]
|
video_stream.height = self.__components.images.shape[1]
|
||||||
video_stream.pix_fmt = 'yuv420p'
|
video_stream.pix_fmt = pix_fmt
|
||||||
|
|
||||||
# Create an audio stream
|
# Create an audio stream
|
||||||
audio_sample_rate = 1
|
audio_sample_rate = 1
|
||||||
@ -505,9 +536,14 @@ class VideoFromComponents(VideoInput):
|
|||||||
|
|
||||||
# Encode video
|
# Encode video
|
||||||
for i, frame in enumerate(self.__components.images):
|
for i, frame in enumerate(self.__components.images):
|
||||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
if is_10bit:
|
||||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
|
||||||
|
else:
|
||||||
|
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||||
|
frame = frame.reformat(format=pix_fmt)
|
||||||
packet = video_stream.encode(frame)
|
packet = video_stream.encode(frame)
|
||||||
output.mux(packet)
|
output.mux(packet)
|
||||||
|
|
||||||
|
|||||||
@ -1400,7 +1400,8 @@ class V3Data(TypedDict):
|
|||||||
class HiddenHolder:
|
class HiddenHolder:
|
||||||
def __init__(self, unique_id: str, prompt: Any,
|
def __init__(self, unique_id: str, prompt: Any,
|
||||||
extra_pnginfo: Any, dynprompt: Any,
|
extra_pnginfo: Any, dynprompt: Any,
|
||||||
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
|
auth_token_comfy_org: str, api_key_comfy_org: str,
|
||||||
|
comfy_usage_source: str = None, **kwargs):
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
@ -1413,6 +1414,8 @@ class HiddenHolder:
|
|||||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||||
self.api_key_comfy_org = api_key_comfy_org
|
self.api_key_comfy_org = api_key_comfy_org
|
||||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||||
|
self.comfy_usage_source = comfy_usage_source
|
||||||
|
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||||
|
|
||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
'''If hidden variable not found, return None.'''
|
'''If hidden variable not found, return None.'''
|
||||||
@ -1429,6 +1432,7 @@ class HiddenHolder:
|
|||||||
dynprompt=d.get(Hidden.dynprompt, None),
|
dynprompt=d.get(Hidden.dynprompt, None),
|
||||||
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
|
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
|
||||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||||
|
comfy_usage_source=d.get(Hidden.comfy_usage_source, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1451,6 +1455,8 @@ class Hidden(str, Enum):
|
|||||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||||
api_key_comfy_org = "API_KEY_COMFY_ORG"
|
api_key_comfy_org = "API_KEY_COMFY_ORG"
|
||||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||||
|
comfy_usage_source = "COMFY_USAGE_SOURCE"
|
||||||
|
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -1654,6 +1660,8 @@ class Schema:
|
|||||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||||
if Hidden.api_key_comfy_org not in self.hidden:
|
if Hidden.api_key_comfy_org not in self.hidden:
|
||||||
self.hidden.append(Hidden.api_key_comfy_org)
|
self.hidden.append(Hidden.api_key_comfy_org)
|
||||||
|
if Hidden.comfy_usage_source not in self.hidden:
|
||||||
|
self.hidden.append(Hidden.comfy_usage_source)
|
||||||
# if is an output_node, will need prompt and extra_pnginfo
|
# if is an output_node, will need prompt and extra_pnginfo
|
||||||
if self.is_output_node:
|
if self.is_output_node:
|
||||||
if Hidden.prompt not in self.hidden:
|
if Hidden.prompt not in self.hidden:
|
||||||
|
|||||||
@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel):
|
|||||||
id: Optional[str] = Field(None, description='Task ID')
|
id: Optional[str] = Field(None, description='Task ID')
|
||||||
|
|
||||||
|
|
||||||
class RunwayTaskStatusEnum(str, Enum):
|
|
||||||
SUCCEEDED = 'SUCCEEDED'
|
|
||||||
RUNNING = 'RUNNING'
|
|
||||||
FAILED = 'FAILED'
|
|
||||||
PENDING = 'PENDING'
|
|
||||||
CANCELLED = 'CANCELLED'
|
|
||||||
THROTTLED = 'THROTTLED'
|
|
||||||
|
|
||||||
|
|
||||||
class RunwayTaskStatusResponse(BaseModel):
|
class RunwayTaskStatusResponse(BaseModel):
|
||||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||||
id: str = Field(..., description='Task ID')
|
id: str = Field(..., description='Task ID')
|
||||||
@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel):
|
|||||||
ge=0.0,
|
ge=0.0,
|
||||||
le=1.0,
|
le=1.0,
|
||||||
)
|
)
|
||||||
status: RunwayTaskStatusEnum
|
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
|
||||||
|
|
||||||
|
|
||||||
class Model4(str, Enum):
|
class Model4(str, Enum):
|
||||||
@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel):
|
|||||||
|
|
||||||
class RunwayTextToImageResponse(BaseModel):
|
class RunwayTextToImageResponse(BaseModel):
|
||||||
id: Optional[str] = Field(None, description='Task ID')
|
id: Optional[str] = Field(None, description='Task ID')
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2IO:
    """Custom socket types for chaining Aleph2 guidance images."""

    # Socket type name for a chain of keyframes.
    KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
    # Socket type name for a chain of prompt images.
    PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
|
||||||
|
|
||||||
|
|
||||||
|
# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to
# choose the request model below. The values match the Aleph2 keyframe union field names.
KEYFRAME_MODE_SECONDS = "seconds"  # absolute time, in seconds, from the start of the input video
KEYFRAME_MODE_AT = "at"  # fraction [0.0, 1.0] of the input video duration

# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`.
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp"  # absolute time, in seconds, from the start of the output video
PROMPT_IMAGE_MODE_POSITION = "position"  # fraction [0.0, 1.0] of the output video duration
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeItem:
    """A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``)."""

    def __init__(self, image, mode: str, value: float):
        """Store the image together with its timing mode and value.

        Args:
            image: The guidance image.
            mode: KEYFRAME_MODE_SECONDS or KEYFRAME_MODE_AT.
            value: The timing value, interpreted according to ``mode``.
        """
        self.image = image
        self.mode = mode  # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT
        self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeChain:
    """An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes."""

    def __init__(self):
        """Start with an empty chain."""
        self.items: list[RunwayAleph2KeyframeItem] = []

    def add(self, item: "RunwayAleph2KeyframeItem") -> None:
        """Append ``item`` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "RunwayAleph2KeyframeChain":
        """Return a new chain with its own list holding the same item objects.

        The list is copied, so adding to the clone leaves this chain untouched;
        the items themselves are shared, not copied.
        """
        c = RunwayAleph2KeyframeChain()
        c.items = list(self.items)
        return c
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2PromptImageItem:
    """A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``)."""

    def __init__(self, image, mode: str, value: float):
        """Store the image together with its position mode and value.

        Args:
            image: The guidance image.
            mode: PROMPT_IMAGE_MODE_TIMESTAMP or PROMPT_IMAGE_MODE_POSITION.
            value: The position value, interpreted according to ``mode``.
        """
        self.image = image
        self.mode = mode  # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION
        self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2PromptImageChain:
    """An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes."""

    def __init__(self):
        """Start with an empty chain."""
        self.items: list[RunwayAleph2PromptImageItem] = []

    def add(self, item: "RunwayAleph2PromptImageItem") -> None:
        """Append ``item`` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "RunwayAleph2PromptImageChain":
        """Return a new chain with its own list holding the same item objects.

        The list is copied, so adding to the clone leaves this chain untouched;
        the items themselves are shared, not copied.
        """
        c = RunwayAleph2PromptImageChain()
        c.items = list(self.items)
        return c
|
||||||
|
|
||||||
|
|
||||||
|
# Request model for one keyframe placed by absolute time in the input video.
class RunwayAleph2KeyframeSeconds(BaseModel):
    # Non-negative offset from the start of the input video.
    seconds: float = Field(
        ...,
        description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
        ge=0.0,
    )
    # Location of the guidance image.
    uri: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
# Request model for one keyframe placed by fractional position in the input video.
class RunwayAleph2KeyframeAt(BaseModel):
    # Validated to lie within [0.0, 1.0].
    at: float = Field(
        ...,
        description="Position as a fraction [0.0, 1.0] of the input video duration.",
        ge=0.0,
        le=1.0,
    )
    # Location of the guidance image.
    uri: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
# Prompt-image position expressed as an absolute time in the output video.
class RunwayAleph2TimestampPosition(BaseModel):
    # Tag identifying this position variant; defaults to "timestamp".
    type: str = Field(default="timestamp")
    # Non-negative offset from the start of the output video.
    timestampSeconds: float = Field(
        ...,
        description="Absolute timestamp in seconds from the start of the output video.",
        ge=0.0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Prompt-image position expressed as a fraction of the output video duration.
class RunwayAleph2RelativePosition(BaseModel):
    # Tag identifying this position variant; defaults to "position".
    type: str = Field(default="position")
    # Validated to lie within [0.0, 1.0].
    positionPercentage: float = Field(
        ...,
        description="Position as a fraction [0.0, 1.0] of the total output video duration.",
        ge=0.0,
        le=1.0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Request model for one prompt image: where it applies in the output video, and its location.
class RunwayAleph2PromptImage(BaseModel):
    # Either an absolute-timestamp position or a fractional position.
    position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
    # Location of the guidance image.
    uri: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
# Content-moderation settings sent with an Aleph2 request.
class RunwayAleph2ContentModeration(BaseModel):
    # Typed as a plain string; the allowed values are named only in the description.
    publicFigureThreshold: str = Field(
        ...,
        description='When set to "low", the content moderation system is less strict about '
        'recognizable public figures. One of "auto" or "low".',
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for an Aleph2 generation task.
class RunwayAleph2Request(BaseModel):
    # Model identifier; defaults to "aleph2".
    model: str = Field(default="aleph2")
    # Required prompt, validated to 1..1000 characters.
    promptText: str = Field(
        ...,
        description="A non-empty string describing what should appear in the output.",
        min_length=1,
        max_length=1000,
    )
    # Location of the input video.
    videoUri: str = Field(...)
    # Validated to the unsigned 32-bit range.
    seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295)
    contentModeration: RunwayAleph2ContentModeration = Field(...)
    # Optional; the "up to 5" limit is stated in the description only, not enforced by the model.
    keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
        None,
        description="Timed guidance images placed at specific points in the input video. Up to 5.",
    )
    # Optional; the "up to 5" limit is stated in the description only, not enforced by the model.
    promptImage: list[RunwayAleph2PromptImage] | None = Field(
        None,
        description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Response returned when an Aleph2 task is created.
class RunwayAleph2Response(BaseModel):
    # Optional task identifier; None when absent from the response.
    id: str | None = Field(None, description="Task ID")
|
||||||
|
|||||||
@ -289,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -357,7 +357,7 @@ class BriaVideoGreenScreen(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -433,7 +433,7 @@ class BriaVideoReplaceBackground(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -452,7 +452,10 @@ class BriaVideoReplaceBackground(IO.ComfyNode):
|
|||||||
validate_video_duration(background_video, max_duration=60.0)
|
validate_video_duration(background_video, max_duration=60.0)
|
||||||
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
||||||
else:
|
else:
|
||||||
background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background")
|
# Bria's replace_background 500s on RGBA, so drop the alpha channel before upload.
|
||||||
|
background_url = await upload_image_to_comfyapi(
|
||||||
|
cls, background_image[:, :, :, :3], wait_label="Uploading background"
|
||||||
|
)
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
||||||
@ -530,7 +533,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -571,7 +574,7 @@ class BriaExtension(ComfyExtension):
|
|||||||
BriaRemoveImageBackground,
|
BriaRemoveImageBackground,
|
||||||
BriaRemoveVideoBackground,
|
BriaRemoveVideoBackground,
|
||||||
BriaVideoGreenScreen,
|
BriaVideoGreenScreen,
|
||||||
# BriaVideoReplaceBackground, # server returns Status 500 when we pass background video
|
BriaVideoReplaceBackground,
|
||||||
BriaTransparentVideoBackground,
|
BriaTransparentVideoBackground,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import (
|
|||||||
Model4,
|
Model4,
|
||||||
ReferenceImage,
|
ReferenceImage,
|
||||||
RunwayTextToImageAspectRatioEnum,
|
RunwayTextToImageAspectRatioEnum,
|
||||||
|
RunwayAleph2IO,
|
||||||
|
RunwayAleph2KeyframeChain,
|
||||||
|
RunwayAleph2KeyframeItem,
|
||||||
|
RunwayAleph2PromptImageChain,
|
||||||
|
RunwayAleph2PromptImageItem,
|
||||||
|
RunwayAleph2Request,
|
||||||
|
RunwayAleph2Response,
|
||||||
|
RunwayAleph2KeyframeSeconds,
|
||||||
|
RunwayAleph2KeyframeAt,
|
||||||
|
RunwayAleph2PromptImage,
|
||||||
|
RunwayAleph2TimestampPosition,
|
||||||
|
RunwayAleph2RelativePosition,
|
||||||
|
RunwayAleph2ContentModeration,
|
||||||
|
KEYFRAME_MODE_SECONDS,
|
||||||
|
KEYFRAME_MODE_AT,
|
||||||
|
PROMPT_IMAGE_MODE_TIMESTAMP,
|
||||||
|
PROMPT_IMAGE_MODE_POSITION,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
image_tensor_pair_to_batch,
|
image_tensor_pair_to_batch,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
validate_image_aspect_ratio,
|
validate_image_aspect_ratio,
|
||||||
|
validate_video_duration,
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
|
upload_image_to_comfyapi,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
download_url_to_video_output,
|
download_url_to_video_output,
|
||||||
download_url_to_image_tensor,
|
download_url_to_image_tensor,
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -45,6 +65,7 @@ from comfy_api_nodes.util import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||||
|
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
|
||||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||||
|
|
||||||
@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256
|
|||||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||||
|
|
||||||
|
|
||||||
class RunwayApiError(Exception):
|
|
||||||
"""Base exception for Runway API errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||||
|
|
||||||
@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_progress_from_task_status(
|
|
||||||
response: TaskStatusResponse,
|
|
||||||
) -> float | None:
|
|
||||||
if hasattr(response, "progress") and response.progress is not None:
|
|
||||||
return response.progress * 100
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||||
"""Returns the image URL from the task status response if it exists."""
|
"""Returns the image URL from the task status response if it exists."""
|
||||||
if hasattr(response, "output") and len(response.output) > 0:
|
if hasattr(response, "output") and len(response.output) > 0:
|
||||||
@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
|||||||
async def get_response(
|
async def get_response(
|
||||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||||
) -> TaskStatusResponse:
|
) -> TaskStatusResponse:
|
||||||
"""Poll the task status until it is finished then get the response."""
|
|
||||||
return await poll_op(
|
return await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||||
response_model=TaskStatusResponse,
|
response_model=TaskStatusResponse,
|
||||||
status_extractor=lambda r: r.status.value,
|
status_extractor=lambda r: r.status,
|
||||||
estimated_duration=estimated_duration,
|
estimated_duration=estimated_duration,
|
||||||
progress_extractor=extract_progress_from_task_status,
|
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -127,7 +133,7 @@ async def generate_video(
|
|||||||
|
|
||||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||||
if not final_response.output:
|
if not final_response.output:
|
||||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||||
|
|
||||||
video_url = get_video_url_from_task_status(final_response)
|
video_url = get_video_url_from_task_status(final_response)
|
||||||
return await download_url_to_video_output(video_url)
|
return await download_url_to_video_output(video_url)
|
||||||
@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
|||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
)
|
)
|
||||||
if len(download_urls) != 2:
|
if len(download_urls) != 2:
|
||||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
raise ValueError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await generate_video(
|
await generate_video(
|
||||||
@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
|||||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||||
)
|
)
|
||||||
if not final_response.output:
|
if not final_response.output:
|
||||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
raise ValueError("Runway task succeeded but no image data found in response.")
|
||||||
|
|
||||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||||
|
|
||||||
|
|
||||||
|
_TIMING_ABSOLUTE = "Absolute time (seconds)"
|
||||||
|
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2KeyframeNode(IO.ComfyNode):
    """Node that appends one timed guidance image to an Aleph2 keyframe chain."""

    @classmethod
    def define_schema(cls):
        """Describe the node: an image, a timing selector, and an optional upstream chain."""
        # Timing expressed as an absolute offset in seconds.
        seconds_option = IO.DynamicCombo.Option(
            _TIMING_ABSOLUTE,
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from start of the input video where this image applies.",
                ),
            ],
        )
        # Timing expressed as a fraction of the video duration.
        fraction_option = IO.DynamicCombo.Option(
            _TIMING_FRACTION,
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Where in the input video this image applies, "
                    "as a fraction of its duration (0.0 = start, 1.0 = end).",
                ),
            ],
        )
        keyframe_type = IO.Custom(RunwayAleph2IO.KEYFRAME)
        return IO.Schema(
            node_id="RunwayAleph2KeyframeNode",
            display_name="Runway Aleph2 Keyframe",
            category="partner/video/Runway",
            description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
            "steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
            "the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
            "'keyframes' input below.",
            inputs=[
                IO.Image.Input(
                    "image",
                    tooltip="The guidance image to apply at the chosen moment of the input video.",
                ),
                IO.DynamicCombo.Input(
                    "timing",
                    options=[seconds_option, fraction_option],
                    tooltip="How to place this image on the input video's timeline.",
                ),
                keyframe_type.Input(
                    "keyframes",
                    optional=True,
                    tooltip="Optional earlier keyframes to chain with this one.",
                ),
            ],
            outputs=[keyframe_type.Output(display_name="keyframes")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        timing: dict,
        keyframes: RunwayAleph2KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain containing any upstream keyframes plus this one.

        ``timing`` is the dynamic-combo payload: ``timing["timing"]`` names the
        selected mode and the matching ``"seconds"`` or ``"fraction"`` key holds
        the value. The upstream chain is cloned rather than modified.
        """
        if keyframes is None:
            chain = RunwayAleph2KeyframeChain()
        else:
            chain = keyframes.clone()

        uses_seconds = timing["timing"] == _TIMING_ABSOLUTE
        mode = KEYFRAME_MODE_SECONDS if uses_seconds else KEYFRAME_MODE_AT
        value = float(timing["seconds" if uses_seconds else "fraction"])

        chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2PromptImageNode(IO.ComfyNode):
    """Node that appends one positioned guidance image to an Aleph2 prompt-image chain."""

    @classmethod
    def define_schema(cls):
        """Describe the node: an image, a position selector, and an optional upstream chain."""
        # Position expressed as an absolute offset in seconds.
        seconds_option = IO.DynamicCombo.Option(
            _TIMING_ABSOLUTE,
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from start of the output video where this image applies.",
                ),
            ],
        )
        # Position expressed as a fraction of the video duration.
        fraction_option = IO.DynamicCombo.Option(
            _TIMING_FRACTION,
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Where in the output video this image applies, "
                    "as a fraction of its duration (0.0 = start, 1.0 = end).",
                ),
            ],
        )
        prompt_image_type = IO.Custom(RunwayAleph2IO.PROMPT_IMAGE)
        return IO.Schema(
            node_id="RunwayAleph2PromptImageNode",
            display_name="Runway Aleph2 Prompt Image",
            category="partner/video/Runway",
            description="Anchor a guidance image to a moment of the output (result) video, to guide what "
            "the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
            "Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
            "'prompt_images' input below.",
            inputs=[
                IO.Image.Input(
                    "image",
                    tooltip="The guidance image to place at the chosen moment of the output video.",
                ),
                IO.DynamicCombo.Input(
                    "position",
                    options=[seconds_option, fraction_option],
                    tooltip="How to place this image on the output video's timeline.",
                ),
                prompt_image_type.Input(
                    "prompt_images",
                    optional=True,
                    tooltip="Optional earlier prompt images to chain with this one.",
                ),
            ],
            outputs=[prompt_image_type.Output(display_name="prompt_images")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        position: dict,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain containing any upstream prompt images plus this one.

        ``position`` is the dynamic-combo payload: ``position["position"]`` names
        the selected mode and the matching ``"seconds"`` or ``"fraction"`` key
        holds the value. The upstream chain is cloned rather than modified.
        """
        if prompt_images is None:
            chain = RunwayAleph2PromptImageChain()
        else:
            chain = prompt_images.clone()

        uses_seconds = position["position"] == _TIMING_ABSOLUTE
        mode = PROMPT_IMAGE_MODE_TIMESTAMP if uses_seconds else PROMPT_IMAGE_MODE_POSITION
        value = float(position["seconds" if uses_seconds else "fraction"])

        chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
    """API node that edits a video with Runway's Aleph2 model from a text prompt."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs, hidden auth fields, output and price badge."""
        return IO.Schema(
            node_id="RunwayAleph2VideoToVideoNode",
            display_name="Runway Aleph2 Video to Video",
            category="partner/video/Runway",
            description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
            "your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
            "the original motion and timing; the output resolution matches the input video, which must be "
            "2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
            "the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Describes what should appear in the output (1-1000 characters).",
                ),
                IO.Video.Input(
                    "video",
                    tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=4294967295,
                    step=1,
                    control_after_generate=True,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Random seed for generation",
                ),
                IO.Combo.Input(
                    "public_figure_threshold",
                    options=["auto", "low"],
                    default="low",
                    tooltip="Content moderation for recognizable public figures.",
                ),
                IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
                    "keyframes",
                    optional=True,
                    tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
                IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
                    "prompt_images",
                    optional=True,
                    tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        video: Input.Video,
        seed: int,
        public_figure_threshold: str = "low",
        keyframes: RunwayAleph2KeyframeChain | None = None,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Validate the inputs, upload the media, submit the edit and return the result video.

        All local validation (prompt length, duration, frame rate, mutual
        exclusivity of the two guidance kinds, item counts, timestamps) runs
        before anything is uploaded, so invalid input fails without paying for
        a video or image upload.

        Raises:
            ValueError: if the frame rate exceeds 30 fps, both guidance kinds are
                supplied, more than 5 guidance items of one kind are supplied, a
                timestamp lies beyond the video duration, the submission returns
                no task id, or the finished task has no output. The
                ``validate_*`` helpers may raise on their own conditions.
        """
        validate_string(prompt, min_length=1, max_length=1000)
        validate_video_duration(
            video,
            min_duration=2.0,
            max_duration=30.0,
        )

        # Best effort: if the frame rate cannot be read, the check is skipped.
        try:
            fps = float(video.get_frame_rate())
        except Exception:
            fps = None
        # Small tolerance so nominal 30 fps sources with rounding noise still pass.
        if fps is not None and fps > 30.0 + 0.01:
            raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")

        keyframe_items = keyframes.items if keyframes is not None else []
        prompt_image_items = prompt_images.items if prompt_images is not None else []

        if keyframe_items and prompt_image_items:
            raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")
        if len(keyframe_items) > 5:
            raise ValueError("Aleph2 supports at most 5 keyframes.")
        if len(prompt_image_items) > 5:
            raise ValueError("Aleph2 supports at most 5 prompt images.")

        # Best effort: without a readable duration the timestamp bound is not enforced.
        video_duration: float | None = None
        try:
            video_duration = video.get_duration()
        except Exception:
            video_duration = None

        def _check_seconds(value: float, label: str) -> None:
            """Reject an absolute timestamp that lies past the end of the input video."""
            if video_duration is not None and value > video_duration + 0.0001:
                raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")

        # Check every absolute timestamp up front, before any upload happens.
        for item in keyframe_items:
            if item.mode == KEYFRAME_MODE_SECONDS:
                _check_seconds(item.value, "Keyframe timestamp")
        for item in prompt_image_items:
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                _check_seconds(item.value, "Prompt image timestamp")

        video_url = await upload_video_to_comfyapi(cls, video)

        keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
        for item in keyframe_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            if item.mode == KEYFRAME_MODE_SECONDS:
                keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
            else:
                keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))

        prompt_image_models: list[RunwayAleph2PromptImage] = []
        for item in prompt_image_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
            else:
                position = RunwayAleph2RelativePosition(positionPercentage=item.value)
            prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))

        initial_response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
            response_model=RunwayAleph2Response,
            data=RunwayAleph2Request(
                promptText=prompt,
                videoUri=video_url,
                seed=seed,
                contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
                # Empty lists are sent as None so the optional fields stay unset.
                keyframes=keyframe_models or None,
                promptImage=prompt_image_models or None,
            ),
        )
        # The response id is optional; without it the status path would be built from "None".
        if not initial_response.id:
            raise ValueError("Runway did not return a task ID for the Aleph2 request.")

        final_response = await get_response(cls, initial_response.id)
        if not final_response.output:
            raise ValueError("Runway task succeeded but no video data found in response.")

        return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
|
||||||
|
|
||||||
|
|
||||||
class RunwayExtension(ComfyExtension):
|
class RunwayExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension):
|
|||||||
RunwayImageToVideoNodeGen3a,
|
RunwayImageToVideoNodeGen3a,
|
||||||
RunwayImageToVideoNodeGen4,
|
RunwayImageToVideoNodeGen4,
|
||||||
RunwayTextToImageNode,
|
RunwayTextToImageNode,
|
||||||
|
RunwayAleph2VideoToVideoNode,
|
||||||
|
RunwayAleph2KeyframeNode,
|
||||||
|
RunwayAleph2PromptImageNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from comfy_api_nodes.util import (
|
|||||||
)
|
)
|
||||||
from comfy_api_nodes.util._helpers import (
|
from comfy_api_nodes.util._helpers import (
|
||||||
default_base_url,
|
default_base_url,
|
||||||
get_auth_header,
|
get_comfy_api_headers,
|
||||||
get_node_id,
|
get_node_id,
|
||||||
is_processing_interrupted,
|
is_processing_interrupted,
|
||||||
)
|
)
|
||||||
@ -174,8 +174,7 @@ async def _stream_sonilo_music(
|
|||||||
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
|
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
|
||||||
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
|
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
|
||||||
|
|
||||||
headers: dict[str, str] = {}
|
headers = get_comfy_api_headers(cls)
|
||||||
headers.update(get_auth_header(cls))
|
|
||||||
headers.update(endpoint.headers)
|
headers.update(endpoint.headers)
|
||||||
|
|
||||||
node_id = get_node_id(cls)
|
node_id = get_node_id(cls)
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from io import BytesIO
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from comfy.deploy_environment import get_deploy_environment
|
||||||
from comfy.model_management import processing_interrupted
|
from comfy.model_management import processing_interrupted
|
||||||
from comfy_api.latest import IO
|
from comfy_api.latest import IO
|
||||||
|
|
||||||
@ -35,6 +36,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
    """Return the identifier of the client that submitted the prompt for this API node.

    Reads ``node_cls.hidden.comfy_usage_source``. When that value is missing or
    empty (the submitting client did not identify itself, i.e. a direct API call
    to this server), ``"comfyui-api"`` is returned instead.
    """
    source = node_cls.hidden.comfy_usage_source
    if source:
        return source
    return "comfyui-api"
|
||||||
|
|
||||||
|
|
||||||
|
def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
    """Build the headers shared by every Comfy API request: auth, deploy environment, usage source.

    Keeping the shared set in one place means each request sends the same
    headers, and a new shared header only has to be added here. Meant for
    relative/cloud URLs resolved against ``default_base_url()``. Because the
    result carries auth, callers must not attach it to arbitrary
    absolute/presigned URLs.
    """
    # Start from a copy of the auth header(s); the fixed keys below are set afterwards.
    headers = dict(get_auth_header(node_cls))
    headers["Comfy-Env"] = get_deploy_environment()
    headers["Comfy-Usage-Source"] = get_usage_source(node_cls)
    return headers
|
||||||
|
|
||||||
|
|
||||||
def default_base_url() -> str:
|
def default_base_url() -> str:
|
||||||
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||||
|
|
||||||
|
|||||||
@ -19,12 +19,10 @@ from comfy import utils
|
|||||||
from comfy_api.latest import IO
|
from comfy_api.latest import IO
|
||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
|
|
||||||
from comfy.deploy_environment import get_deploy_environment
|
|
||||||
|
|
||||||
from . import request_logger
|
from . import request_logger
|
||||||
from ._helpers import (
|
from ._helpers import (
|
||||||
default_base_url,
|
default_base_url,
|
||||||
get_auth_header,
|
get_comfy_api_headers,
|
||||||
get_node_id,
|
get_node_id,
|
||||||
is_processing_interrupted,
|
is_processing_interrupted,
|
||||||
sleep_with_interrupt,
|
sleep_with_interrupt,
|
||||||
@ -645,8 +643,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
|
|
||||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
payload_headers.update(get_comfy_api_headers(cfg.node_cls))
|
||||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
|
||||||
if cfg.endpoint.headers:
|
if cfg.endpoint.headers:
|
||||||
payload_headers.update(cfg.endpoint.headers)
|
payload_headers.update(cfg.endpoint.headers)
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from folder_paths import get_output_directory
|
|||||||
from . import request_logger
|
from . import request_logger
|
||||||
from ._helpers import (
|
from ._helpers import (
|
||||||
default_base_url,
|
default_base_url,
|
||||||
get_auth_header,
|
get_comfy_api_headers,
|
||||||
is_processing_interrupted,
|
is_processing_interrupted,
|
||||||
sleep_with_interrupt,
|
sleep_with_interrupt,
|
||||||
to_aiohttp_url,
|
to_aiohttp_url,
|
||||||
@ -64,7 +64,7 @@ async def download_url_to_bytesio(
|
|||||||
if cls is None:
|
if cls is None:
|
||||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||||
headers = get_auth_header(cls)
|
headers = get_comfy_api_headers(cls)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
attempt += 1
|
attempt += 1
|
||||||
|
|||||||
@ -245,6 +245,11 @@ class KV_Attn_Input:
|
|||||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
kk, vv = self.cache[cache_key]
|
kk, vv = self.cache[cache_key]
|
||||||
|
|
||||||
|
# Fix batch size changing.
|
||||||
|
kk = comfy.utils.repeat_to_batch_size(kk, k.shape[0])
|
||||||
|
vv = comfy.utils.repeat_to_batch_size(vv, v.shape[0])
|
||||||
|
|
||||||
self.set_cache = False
|
self.set_cache = False
|
||||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||||
|
|
||||||
|
|||||||
@ -134,6 +134,17 @@ class CreateVideo(io.ComfyNode):
|
|||||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||||
|
io.Int.Input(
|
||||||
|
"bit_depth",
|
||||||
|
min=8,
|
||||||
|
max=10,
|
||||||
|
default=8,
|
||||||
|
step=2,
|
||||||
|
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
|
||||||
|
" banding, but some players and downstream nodes may not support it.",
|
||||||
|
optional=True,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Video.Output(),
|
io.Video.Output(),
|
||||||
@ -141,9 +152,14 @@ class CreateVideo(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
def execute(
|
||||||
|
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
|
||||||
|
) -> io.NodeOutput:
|
||||||
return io.NodeOutput(
|
return io.NodeOutput(
|
||||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
InputImpl.VideoFromComponents(
|
||||||
|
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
|
||||||
|
bit_depth=bit_depth,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
class GetVideoComponents(io.ComfyNode):
|
class GetVideoComponents(io.ComfyNode):
|
||||||
@ -154,7 +170,7 @@ class GetVideoComponents(io.ComfyNode):
|
|||||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||||
display_name="Get Video Components",
|
display_name="Get Video Components",
|
||||||
category="video",
|
category="video",
|
||||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||||
],
|
],
|
||||||
@ -162,13 +178,14 @@ class GetVideoComponents(io.ComfyNode):
|
|||||||
io.Image.Output(display_name="images"),
|
io.Image.Output(display_name="images"),
|
||||||
io.Audio.Output(display_name="audio"),
|
io.Audio.Output(display_name="audio"),
|
||||||
io.Float.Output(display_name="fps"),
|
io.Float.Output(display_name="fps"),
|
||||||
|
io.Int.Output(display_name="bit_depth"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||||
components = video.get_components()
|
components = video.get_components()
|
||||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
|
||||||
|
|
||||||
|
|
||||||
class LoadVideo(io.ComfyNode):
|
class LoadVideo(io.ComfyNode):
|
||||||
|
|||||||
@ -200,6 +200,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||||
if io.Hidden.api_key_comfy_org.name in hidden:
|
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||||
|
if io.Hidden.comfy_usage_source.name in hidden:
|
||||||
|
hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None)
|
||||||
else:
|
else:
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
h = valid_inputs["hidden"]
|
h = valid_inputs["hidden"]
|
||||||
@ -216,6 +218,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||||
if h[x] == "API_KEY_COMFY_ORG":
|
if h[x] == "API_KEY_COMFY_ORG":
|
||||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||||
|
if h[x] == "COMFY_USAGE_SOURCE":
|
||||||
|
input_data_all[x] = [extra_data.get("comfy_usage_source", None)]
|
||||||
v3_data["hidden_inputs"] = hidden_inputs_v3
|
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||||
return input_data_all, missing_keys, v3_data
|
return input_data_all, missing_keys, v3_data
|
||||||
|
|
||||||
|
|||||||
@ -973,6 +973,11 @@ class PromptServer():
|
|||||||
|
|
||||||
if "client_id" in json_data:
|
if "client_id" in json_data:
|
||||||
extra_data["client_id"] = json_data["client_id"]
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
|
|
||||||
|
if "comfy_usage_source" not in extra_data:
|
||||||
|
usage_source = request.headers.get("Comfy-Usage-Source")
|
||||||
|
if usage_source:
|
||||||
|
extra_data["comfy_usage_source"] = usage_source
|
||||||
if valid[0]:
|
if valid[0]:
|
||||||
outputs_to_execute = valid[2]
|
outputs_to_execute = valid[2]
|
||||||
sensitive = {}
|
sensitive = {}
|
||||||
|
|||||||
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
from fractions import Fraction
|
||||||
|
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||||
|
from comfy_api.latest._util.video_types import VideoComponents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def gradient_components():
    """Video components holding a shallow left-to-right ramp from 0.25 to 0.30.

    The ramp spans only a small slice of the value range, so its smoothness
    after encoding depends on the bit depth used.
    """
    n_frames, rows, cols, channels = 3, 64, 64, 3
    # One value per column, broadcast over every frame, row and channel.
    line = torch.linspace(0.25, 0.30, cols)
    tiled = line.view(1, 1, cols, 1).expand(n_frames, rows, cols, channels)
    return VideoComponents(images=tiled.contiguous(), frame_rate=Fraction(30))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
    """Path to an mp4 saved from the gradient with the default bit depth.

    `test_create_video_bit_depth` expects this file to be 8-bit h264.
    """
    out_dir = tmp_path_factory.mktemp("video")
    target = str(out_dir / "src8.mp4")
    video = VideoFromComponents(gradient_components)
    video.save_to(target)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
    """Path to an mp4 saved from the gradient with ``bit_depth=10``.

    `test_create_video_bit_depth` expects this file to be 10-bit h264.
    """
    out_dir = tmp_path_factory.mktemp("video")
    target = str(out_dir / "src10.mp4")
    video = VideoFromComponents(gradient_components, bit_depth=10)
    video.save_to(target)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def probe(path):
    """Return ``(codec name, pixel format name, bit depth)`` for the first video stream.

    The bit depth is the largest ``bits`` value among the pixel format's
    components.
    """
    with av.open(path) as container:
        video = container.streams.video[0]
        pix_fmt = video.format
        depth = max(component.bits for component in pix_fmt.components)
        return (video.codec.name, pix_fmt.name, depth)
|
||||||
|
|
||||||
|
|
||||||
|
def decoded_levels(path):
    """Count the distinct values in the first decoded frame, as a banding measure.

    The frame is converted to ``gbrpf32le`` and only index 0 of the last
    axis is inspected; fewer distinct values means coarser gradients.
    """
    with av.open(path) as container:
        video = container.streams.video[0]
        first_frame = next(container.decode(video))
        pixels = first_frame.to_ndarray(format="gbrpf32le")
        distinct = np.unique(pixels[..., 0])
        return len(distinct)
|
||||||
|
|
||||||
|
|
||||||
|
def video_packet_bytes(path):
    """Collect the raw payload of every non-empty packet in the first video stream.

    Two files yield equal lists only when their video packets are
    byte-identical, which is how the tests recognise a stream copy (remux)
    as opposed to a re-encode.
    """
    payloads = []
    with av.open(path) as container:
        video = container.streams.video[0]
        for packet in container.demux(video):
            # Packets reporting a size of zero carry no payload; skip them.
            if packet.size:
                payloads.append(bytes(packet))
    return payloads
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_video_bit_depth(src8, src10):
    """The requested bit depth decides the encoded pixel format, and 10-bit bands less.

    The default produces 8-bit ``yuv420p``; ``bit_depth=10`` produces
    ``yuv420p10le``. The 10-bit file must decode to more than twice as many
    distinct levels as the 8-bit one.
    """
    expected_8bit = ("h264", "yuv420p", 8)
    expected_10bit = ("h264", "yuv420p10le", 10)
    assert probe(src8) == expected_8bit
    assert probe(src10) == expected_10bit

    levels_10bit = decoded_levels(src10)
    levels_8bit = decoded_levels(src8)
    assert levels_10bit > 2 * levels_8bit
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
    """Saving a file-backed video without ``bit_depth`` must copy the stream unchanged.

    For both the 8-bit and the 10-bit source, the saved file must report the
    same codec, pixel format and depth, and its video packets must be
    byte-identical to the source's.
    """
    sources = {"p8": src8, "p10": src10}
    for label, source in sources.items():
        saved = str(tmp_path / f"{label}.mp4")
        VideoFromFile(source).save_to(saved)
        assert probe(saved) == probe(source)
        assert video_packet_bytes(saved) == video_packet_bytes(source)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
    """Passing a ``bit_depth`` that differs from the source must yield that depth.

    Both directions are covered: 10-bit down to 8-bit, and 8-bit up to 10-bit.
    """
    # 10-bit source saved as 8-bit.
    downconverted = str(tmp_path / "down8.mp4")
    VideoFromFile(src10).save_to(downconverted, bit_depth=8)
    assert probe(downconverted) == ("h264", "yuv420p", 8)

    # 8-bit source saved as 10-bit.
    upconverted = str(tmp_path / "up10.mp4")
    VideoFromFile(src8).save_to(upconverted, bit_depth=10)
    assert probe(upconverted) == ("h264", "yuv420p10le", 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim_keeps_source_depth(src10, tmp_path):
    """Trimming a 10-bit source and saving it must still produce a 10-bit file.

    The clip is cut to a duration of 1/30 s starting at 0, with
    ``strict_duration`` disabled.
    """
    trimmed_path = str(tmp_path / "trim.mp4")
    source = VideoFromFile(src10)
    trimmed = source.as_trimmed(start_time=0, duration=1 / 30, strict_duration=False)
    trimmed.save_to(trimmed_path)
    assert probe(trimmed_path) == ("h264", "yuv420p10le", 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bit_depth(gradient_components, src8, src10):
    """``get_bit_depth`` must report the depth for file-backed and component-backed videos.

    File-backed videos report the depth of the file they wrap. Component-backed
    videos report the depth they were constructed with, which is 8 when none
    is given.
    """
    # File-backed videos.
    from_8bit_file = VideoFromFile(src8)
    from_10bit_file = VideoFromFile(src10)
    assert from_8bit_file.get_bit_depth() == 8
    assert from_10bit_file.get_bit_depth() == 10

    # Component-backed videos: explicit depth first, then the default.
    explicit_10bit = VideoFromComponents(gradient_components, bit_depth=10)
    assert explicit_10bit.get_bit_depth() == 10
    default_depth = VideoFromComponents(gradient_components)
    assert default_depth.get_bit_depth() == 8
|
||||||
Loading…
Reference in New Issue
Block a user