diff --git a/README.md b/README.md index ee1024de5..786a14166 100644 --- a/README.md +++ b/README.md @@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step | Flag | Description | |------|-------------| | `--enable-manager` | Enable ComfyUI-Manager | -| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | +| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) | | `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cba0dfa34..e7ee0d5eb 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") +cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -133,7 +134,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.") manager_group = parser.add_mutually_exclusive_group() manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.") -manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager") +manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.") vram_group = parser.add_mutually_exclusive_group() @@ -249,6 +250,9 @@ else: if args.cache_ram is not None and len(args.cache_ram) > 2: parser.error("--cache-ram accepts at most two values: active GB and inactive GB") +if args.high_ram: + args.cache_classic = True + if args.windows_standalone_build: args.auto_launch = True @@ -258,6 +262,10 @@ if args.disable_auto_launch: if args.force_fp16: args.fp16_unet = True +# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'. +if args.enable_manager_legacy_ui: + args.enable_manager = True + # '--fast' is not provided, use an empty set if args.fast is None: diff --git a/comfy/ldm/ideogram4/model.py b/comfy/ldm/ideogram4/model.py index b86c65bf0..4ea5b8aaf 100644 --- a/comfy/ldm/ideogram4/model.py +++ b/comfy/ldm/ideogram4/model.py @@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module): self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device) self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device) - def forward(self, x): + def forward(self, x, dtype): x = x.to(torch.float32) scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min) emb = _sinusoidal_embedding(scaled, self.dim) - emb = emb.to(self.mlp_in.weight.dtype) + emb = emb.to(dtype) emb = F.silu(self.mlp_in(emb)) return self.mlp_out(emb) @@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module): x = x * output_image_mask h = self.input_proj(x) * output_image_mask - t_cond = self.t_embedding(t) + t_cond = self.t_embedding(t, dtype=x.dtype) if t.dim() == 1: t_cond = t_cond.unsqueeze(1) adaln_input = F.silu(self.adaln_proj(t_cond)) diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index 82edc92da..e9ca5229d 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from comfy.ldm.lightricks.model import Timesteps from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.modules.attention import optimized_attention_masked import comfy.model_management import comfy.ldm.common_dit @@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis): if x.shape[1] == 0: return x - t_ = x.reshape(*x.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x.shape).to(dtype=x.dtype) + return apply_rope1(x, freqs_cis) def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: diff --git a/comfy/model_management.py b/comfy/model_management.py index 55ddaab8e..b15d08ba1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -643,6 +643,8 @@ def free_pins(size, evict_active=False): return freed_total def ensure_pin_budget(size, evict_active=False): + if args.high_ram: + return True if args.fast_disk: shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY else: @@ -1496,6 +1498,8 @@ if not args.disable_pinned_memory: PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def pinned_hostbuf_size(size): + if args.high_ram: + return max(0, int(size * 2)) return max(0, int(min(size, MAX_PINNED_MEMORY) * 2)) def discard_cuda_async_error(): diff --git a/comfy/ops.py b/comfy/ops.py index 3c9912aae..3f088a962 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if pin is not None: cast_maybe_lowvram_patch([pin], dest, offload_stream) return - if signature is None: + if signature is None or args.high_ram: comfy.pinned_memory.pin_memory(m, subset=subset, size=size) pin = comfy.pinned_memory.get_pin(m, subset=subset) cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest) diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 8fff52c16..e2e99521f 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -27,10 +27,13 @@ class VideoInput(ABC): path: Union[str, IO[bytes]], format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, + bit_depth: int | None = None, ): """ Abstract method to save the video input to a file. + + bit_depth selects the encoded bit depth; None keeps the video's native depth. """ pass @@ -83,6 +86,14 @@ class VideoInput(ABC): components = self.get_components() return components.images.shape[2], components.images.shape[1] + def get_bit_depth(self) -> int: + """ + Returns the bit depth of the video (e.g. 8 or 10). + + Default implementation returns 8; subclasses report their real depth. + """ + return 8 + def get_duration(self) -> float: """ Returns the duration of the video in seconds. diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 4a12ff9c1..dfdf58515 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -52,6 +52,12 @@ def get_open_write_kwargs( return open_kwargs +def video_stream_bit_depth(stream) -> int: + if stream is None or stream.format is None or not stream.format.components: + return 8 + return max(component.bits for component in stream.format.components) + + class VideoFromFile(VideoInput): """ Class representing video input from a file. @@ -97,6 +103,13 @@ class VideoFromFile(VideoInput): return stream.width, stream.height raise ValueError(f"No video stream found in file '{self.__file}'") + def get_bit_depth(self) -> int: + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode="r") as container: + video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None + return video_stream_bit_depth(video_stream) + def get_duration(self) -> float: """ Returns the duration of the video in seconds. @@ -377,25 +390,32 @@ class VideoFromFile(VideoInput): format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, + bit_depth: int | None = None, ): if isinstance(self.__file, io.BytesIO): self.__file.seek(0) # Reset the BytesIO object to the beginning with av.open(self.__file, mode='r') as container: container_format = container.format.name - video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None + video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None + video_encoding = video_stream.codec.name if video_stream is not None else None + source_bit_depth = video_stream_bit_depth(video_stream) reuse_streams = True if format != VideoContainer.AUTO and format not in container_format.split(","): reuse_streams = False if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: reuse_streams = False + if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth: + reuse_streams = False if self.__start_time or self.__duration: reuse_streams = False if not reuse_streams: + if bit_depth is None: + bit_depth = source_bit_depth components = self.get_components_internal(container) video = VideoFromComponents(components) return video.save_to( - path, format=format, codec=codec, metadata=metadata + path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth, ) streams = container.streams @@ -451,8 +471,10 @@ class VideoFromComponents(VideoInput): Class representing video input from tensors. """ - def __init__(self, components: VideoComponents): + def __init__(self, components: VideoComponents, bit_depth: int = 8): self.__components = components + # Tensor components have no inherent bit depth; this is the depth used when encoding. + self.__bit_depth = bit_depth def get_components(self) -> VideoComponents: return VideoComponents( @@ -461,18 +483,26 @@ class VideoFromComponents(VideoInput): frame_rate=self.__components.frame_rate, ) + def get_bit_depth(self) -> int: + return self.__bit_depth + def save_to( self, path: str, format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, + bit_depth: int | None = None, ): """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") + # None means "use the depth this video was created with" (CreateVideo's choice). + if bit_depth is None: + bit_depth = self.__bit_depth + is_10bit = bit_depth >= 10 extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value @@ -488,10 +518,11 @@ class VideoFromComponents(VideoInput): frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) # Create a video stream + pix_fmt = "yuv420p10le" if is_10bit else "yuv420p" video_stream = output.add_stream('h264', rate=frame_rate) video_stream.width = self.__components.images.shape[2] video_stream.height = self.__components.images.shape[1] - video_stream.pix_fmt = 'yuv420p' + video_stream.pix_fmt = pix_fmt # Create an audio stream audio_sample_rate = 1 @@ -505,9 +536,14 @@ class VideoFromComponents(VideoInput): # Encode video for i, frame in enumerate(self.__components.images): - img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) - frame = av.VideoFrame.from_ndarray(img, format='rgb24') - frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 + if is_10bit: + # 16-bit RGB keeps float precision through the conversion to 10-bit YUV. + img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format="rgb48le") + else: + img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame = frame.reformat(format=pix_fmt) packet = video_stream.encode(frame) output.mux(packet) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 37614a4c3..012fae3ac 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1400,7 +1400,8 @@ class V3Data(TypedDict): class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, extra_pnginfo: Any, dynprompt: Any, - auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + auth_token_comfy_org: str, api_key_comfy_org: str, + comfy_usage_source: str = None, **kwargs): self.unique_id = unique_id """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" self.prompt = prompt @@ -1413,6 +1414,8 @@ class HiddenHolder: """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" self.api_key_comfy_org = api_key_comfy_org """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + self.comfy_usage_source = comfy_usage_source + """COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header.""" def __getattr__(self, key: str): '''If hidden variable not found, return None.''' @@ -1429,6 +1432,7 @@ class HiddenHolder: dynprompt=d.get(Hidden.dynprompt, None), auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None), api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), + comfy_usage_source=d.get(Hidden.comfy_usage_source, None), ) @classmethod @@ -1451,6 +1455,8 @@ class Hidden(str, Enum): """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" api_key_comfy_org = "API_KEY_COMFY_ORG" """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + comfy_usage_source = "COMFY_USAGE_SOURCE" + """COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header.""" @dataclass @@ -1654,6 +1660,8 @@ class Schema: self.hidden.append(Hidden.auth_token_comfy_org) if Hidden.api_key_comfy_org not in self.hidden: self.hidden.append(Hidden.api_key_comfy_org) + if Hidden.comfy_usage_source not in self.hidden: + self.hidden.append(Hidden.comfy_usage_source) # if is an output_node, will need prompt and extra_pnginfo if self.is_output_node: if Hidden.prompt not in self.hidden: diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py index df6f2b845..6878aa6f0 100644 --- a/comfy_api_nodes/apis/runway.py +++ b/comfy_api_nodes/apis/runway.py @@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel): id: Optional[str] = Field(None, description='Task ID') -class RunwayTaskStatusEnum(str, Enum): - SUCCEEDED = 'SUCCEEDED' - RUNNING = 'RUNNING' - FAILED = 'FAILED' - PENDING = 'PENDING' - CANCELLED = 'CANCELLED' - THROTTLED = 'THROTTLED' - - class RunwayTaskStatusResponse(BaseModel): createdAt: datetime = Field(..., description='Task creation timestamp') id: str = Field(..., description='Task ID') @@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel): ge=0.0, le=1.0, ) - status: RunwayTaskStatusEnum + status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED") class Model4(str, Enum): @@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel): class RunwayTextToImageResponse(BaseModel): id: Optional[str] = Field(None, description='Task ID') + + +class RunwayAleph2IO: + """Custom socket types for chaining Aleph2 guidance images.""" + + KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME" + PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE" + + +# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to +# choose the request model below. The values match the Aleph2 keyframe union field names. +KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the input video +KEYFRAME_MODE_AT = "at" # fraction [0.0, 1.0] of the input video duration + +# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`. +PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp" # absolute time, in seconds, from the start of the output video +PROMPT_IMAGE_MODE_POSITION = "position" # fraction [0.0, 1.0] of the output video duration + + +class RunwayAleph2KeyframeItem: + """A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``).""" + + def __init__(self, image, mode: str, value: float): + self.image = image + self.mode = mode # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT + self.value = value + + +class RunwayAleph2KeyframeChain: + """An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes.""" + + def __init__(self): + self.items: list[RunwayAleph2KeyframeItem] = [] + + def add(self, item: RunwayAleph2KeyframeItem) -> None: + self.items.append(item) + + def clone(self) -> "RunwayAleph2KeyframeChain": + c = RunwayAleph2KeyframeChain() + c.items = list(self.items) + return c + + +class RunwayAleph2PromptImageItem: + """A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``).""" + + def __init__(self, image, mode: str, value: float): + self.image = image + self.mode = mode # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION + self.value = value + + +class RunwayAleph2PromptImageChain: + """An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes.""" + + def __init__(self): + self.items: list[RunwayAleph2PromptImageItem] = [] + + def add(self, item: RunwayAleph2PromptImageItem) -> None: + self.items.append(item) + + def clone(self) -> "RunwayAleph2PromptImageChain": + c = RunwayAleph2PromptImageChain() + c.items = list(self.items) + return c + + +class RunwayAleph2KeyframeSeconds(BaseModel): + seconds: float = Field( + ..., + description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.", + ge=0.0, + ) + uri: str = Field(...) + + +class RunwayAleph2KeyframeAt(BaseModel): + at: float = Field( + ..., + description="Position as a fraction [0.0, 1.0] of the input video duration.", + ge=0.0, + le=1.0, + ) + uri: str = Field(...) + + +class RunwayAleph2TimestampPosition(BaseModel): + type: str = Field(default="timestamp") + timestampSeconds: float = Field( + ..., + description="Absolute timestamp in seconds from the start of the output video.", + ge=0.0, + ) + + +class RunwayAleph2RelativePosition(BaseModel): + type: str = Field(default="position") + positionPercentage: float = Field( + ..., + description="Position as a fraction [0.0, 1.0] of the total output video duration.", + ge=0.0, + le=1.0, + ) + + +class RunwayAleph2PromptImage(BaseModel): + position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition + uri: str = Field(...) + + +class RunwayAleph2ContentModeration(BaseModel): + publicFigureThreshold: str = Field( + ..., + description='When set to "low", the content moderation system is less strict about ' + 'recognizable public figures. One of "auto" or "low".', + ) + + +class RunwayAleph2Request(BaseModel): + model: str = Field(default="aleph2") + promptText: str = Field( + ..., + description="A non-empty string describing what should appear in the output.", + min_length=1, + max_length=1000, + ) + videoUri: str = Field(...) + seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295) + contentModeration: RunwayAleph2ContentModeration = Field(...) + keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field( + None, + description="Timed guidance images placed at specific points in the input video. Up to 5.", + ) + promptImage: list[RunwayAleph2PromptImage] | None = Field( + None, + description="Up to 5 image keyframes for guiding the edit at specific points in the output video.", + ) + + +class RunwayAleph2Response(BaseModel): + id: str | None = Field(None, description="Task ID") diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index e138fafa9..090154afb 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -289,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", ), ) @@ -357,7 +357,7 @@ class BriaVideoGreenScreen(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", ), ) @@ -433,7 +433,7 @@ class BriaVideoReplaceBackground(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", ), ) @@ -452,7 +452,10 @@ class BriaVideoReplaceBackground(IO.ComfyNode): validate_video_duration(background_video, max_duration=60.0) background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background") else: - background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background") + # Bria's replace_background 500s on RGBA, so drop the alpha channel before upload. + background_url = await upload_image_to_comfyapi( + cls, background_image[:, :, :, :3], wait_label="Uploading background" + ) response = await sync_op( cls, ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"), @@ -530,7 +533,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", ), ) @@ -571,7 +574,7 @@ class BriaExtension(ComfyExtension): BriaRemoveImageBackground, BriaRemoveVideoBackground, BriaVideoGreenScreen, - # BriaVideoReplaceBackground, # server returns Status 500 when we pass background video + BriaVideoReplaceBackground, BriaTransparentVideoBackground, ] diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index b9c5c81a1..013a193d9 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import ( Model4, ReferenceImage, RunwayTextToImageAspectRatioEnum, + RunwayAleph2IO, + RunwayAleph2KeyframeChain, + RunwayAleph2KeyframeItem, + RunwayAleph2PromptImageChain, + RunwayAleph2PromptImageItem, + RunwayAleph2Request, + RunwayAleph2Response, + RunwayAleph2KeyframeSeconds, + RunwayAleph2KeyframeAt, + RunwayAleph2PromptImage, + RunwayAleph2TimestampPosition, + RunwayAleph2RelativePosition, + RunwayAleph2ContentModeration, + KEYFRAME_MODE_SECONDS, + KEYFRAME_MODE_AT, + PROMPT_IMAGE_MODE_TIMESTAMP, + PROMPT_IMAGE_MODE_POSITION, ) from comfy_api_nodes.util import ( image_tensor_pair_to_batch, validate_string, validate_image_dimensions, validate_image_aspect_ratio, + validate_video_duration, upload_images_to_comfyapi, + upload_image_to_comfyapi, + upload_video_to_comfyapi, download_url_to_video_output, download_url_to_image_tensor, ApiEndpoint, @@ -45,6 +65,7 @@ from comfy_api_nodes.util import ( ) PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" +PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" PATH_GET_TASK_STATUS = "/proxy/runway/tasks" @@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256 AVERAGE_DURATION_T2I_SECONDS = 41 -class RunwayApiError(Exception): - """Base exception for Runway API errors.""" - - pass - - class RunwayGen4TurboAspectRatio(str, Enum): """Aspect ratios supported for Image to Video API when using gen4_turbo model.""" @@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: return None -def extract_progress_from_task_status( - response: TaskStatusResponse, -) -> float | None: - if hasattr(response, "progress") and response.progress is not None: - return response.progress * 100 - return None - - def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the image URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: @@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: async def get_response( cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None ) -> TaskStatusResponse: - """Poll the task status until it is finished then get the response.""" return await poll_op( cls, ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), response_model=TaskStatusResponse, - status_extractor=lambda r: r.status.value, + status_extractor=lambda r: r.status, estimated_duration=estimated_duration, - progress_extractor=extract_progress_from_task_status, + progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None, ) @@ -127,7 +133,7 @@ async def generate_video( final_response = await get_response(cls, initial_response.id, estimated_duration) if not final_response.output: - raise RunwayApiError("Runway task succeeded but no video data found in response.") + raise ValueError("Runway task succeeded but no video data found in response.") video_url = get_video_url_from_task_status(final_response) return await download_url_to_video_output(video_url) @@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): mime_type="image/png", ) if len(download_urls) != 2: - raise RunwayApiError("Failed to upload one or more images to comfy api.") + raise ValueError("Failed to upload one or more images to comfy api.") return IO.NodeOutput( await generate_video( @@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode): estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) if not final_response.output: - raise RunwayApiError("Runway task succeeded but no image data found in response.") + raise ValueError("Runway task succeeded but no image data found in response.") return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response))) +_TIMING_ABSOLUTE = "Absolute time (seconds)" +_TIMING_FRACTION = "Fraction of duration (0.0-1.0)" + + +class RunwayAleph2KeyframeNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RunwayAleph2KeyframeNode", + display_name="Runway Aleph2 Keyframe", + category="partner/video/Runway", + description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 " + "steers the edit at that point of your footage. Connect this to the 'keyframes' input of " + "the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional " + "'keyframes' input below.", + inputs=[ + IO.Image.Input( + "image", + tooltip="The guidance image to apply at the chosen moment of the input video.", + ), + IO.DynamicCombo.Input( + "timing", + options=[ + IO.DynamicCombo.Option( + _TIMING_ABSOLUTE, + [ + IO.Float.Input( + "seconds", + default=0.0, + min=0.0, + max=30.0, + step=0.1, + display_mode=IO.NumberDisplay.number, + tooltip="Time in seconds from start of the input video where this image applies.", + ), + ], + ), + IO.DynamicCombo.Option( + _TIMING_FRACTION, + [ + IO.Float.Input( + "fraction", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Where in the input video this image applies, " + "as a fraction of its duration (0.0 = start, 1.0 = end).", + ), + ], + ), + ], + tooltip="How to place this image on the input video's timeline.", + ), + IO.Custom(RunwayAleph2IO.KEYFRAME).Input( + "keyframes", + optional=True, + tooltip="Optional earlier keyframes to chain with this one.", + ), + ], + outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")], + ) + + @classmethod + def execute( + cls, + image: Input.Image, + timing: dict, + keyframes: RunwayAleph2KeyframeChain | None = None, + ) -> IO.NodeOutput: + chain = keyframes.clone() if keyframes is not None else RunwayAleph2KeyframeChain() + if timing["timing"] == _TIMING_ABSOLUTE: + mode, value = KEYFRAME_MODE_SECONDS, float(timing["seconds"]) + else: + mode, value = KEYFRAME_MODE_AT, float(timing["fraction"]) + chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value)) + return IO.NodeOutput(chain) + + +class RunwayAleph2PromptImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RunwayAleph2PromptImageNode", + display_name="Runway Aleph2 Prompt Image", + category="partner/video/Runway", + description="Anchor a guidance image to a moment of the output (result) video, to guide what " + "the edited video looks like at that point. Connect this to the 'prompt_images' input of the " + "Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional " + "'prompt_images' input below.", + inputs=[ + IO.Image.Input( + "image", + tooltip="The guidance image to place at the chosen moment of the output video.", + ), + IO.DynamicCombo.Input( + "position", + options=[ + IO.DynamicCombo.Option( + _TIMING_ABSOLUTE, + [ + IO.Float.Input( + "seconds", + default=0.0, + min=0.0, + max=30.0, + step=0.1, + display_mode=IO.NumberDisplay.number, + tooltip="Time in seconds from start of the output video where this image applies.", + ), + ], + ), + IO.DynamicCombo.Option( + _TIMING_FRACTION, + [ + IO.Float.Input( + "fraction", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Where in the output video this image applies, " + "as a fraction of its duration (0.0 = start, 1.0 = end).", + ), + ], + ), + ], + tooltip="How to place this image on the output video's timeline.", + ), + IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input( + "prompt_images", + optional=True, + tooltip="Optional earlier prompt images to chain with this one.", + ), + ], + outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")], + ) + + @classmethod + def execute( + cls, + image: Input.Image, + position: dict, + prompt_images: RunwayAleph2PromptImageChain | None = None, + ) -> IO.NodeOutput: + chain = prompt_images.clone() if prompt_images is not None else RunwayAleph2PromptImageChain() + if position["position"] == _TIMING_ABSOLUTE: + mode, value = PROMPT_IMAGE_MODE_TIMESTAMP, float(position["seconds"]) + else: + mode, value = PROMPT_IMAGE_MODE_POSITION, float(position["fraction"]) + chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value)) + return IO.NodeOutput(chain) + + +class RunwayAleph2VideoToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RunwayAleph2VideoToVideoNode", + display_name="Runway Aleph2 Video to Video", + category="partner/video/Runway", + description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms " + "your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping " + "the original motion and timing; the output resolution matches the input video, which must be " + "2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to " + "the input video) or prompt images (anchored to the output video) - use one or the other, not both.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describes what should appear in the output (1-1000 characters).", + ), + IO.Video.Input( + "video", + tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967295, + step=1, + control_after_generate=True, + display_mode=IO.NumberDisplay.number, + tooltip="Random seed for generation", + ), + IO.Combo.Input( + "public_figure_threshold", + options=["auto", "low"], + default="low", + tooltip="Content moderation for recognizable public figures.", + ), + IO.Custom(RunwayAleph2IO.KEYFRAME).Input( + "keyframes", + optional=True, + tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). " + "Use keyframes or prompt images, not both.", + ), + IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input( + "prompt_images", + optional=True, + tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). " + "Use keyframes or prompt images, not both.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + video: Input.Video, + seed: int, + public_figure_threshold: str = "low", + keyframes: RunwayAleph2KeyframeChain | None = None, + prompt_images: RunwayAleph2PromptImageChain | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=1000) + validate_video_duration( + video, + min_duration=2.0, + max_duration=30.0, + ) + try: + fps = float(video.get_frame_rate()) + except Exception: + fps = None + if fps is not None and fps > 30.0 + 0.01: + raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.") + + if (keyframes and keyframes.items) and (prompt_images and prompt_images.items): + raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.") + + video_duration: float | None = None + try: + video_duration = video.get_duration() + except Exception: + video_duration = None + + def _check_seconds(value: float, label: str) -> None: + if video_duration is not None and value > video_duration + 0.0001: + raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).") + + video_url = await upload_video_to_comfyapi(cls, video) + + keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = [] + if keyframes is not None: + if len(keyframes.items) > 5: + raise ValueError("Aleph2 supports at most 5 keyframes.") + for item in keyframes.items: + image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png") + if item.mode == KEYFRAME_MODE_SECONDS: + _check_seconds(item.value, "Keyframe timestamp") + keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url)) + else: + keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url)) + + prompt_image_models: list[RunwayAleph2PromptImage] = [] + if prompt_images is not None: + if len(prompt_images.items) > 5: + raise ValueError("Aleph2 supports at most 5 prompt images.") + for item in prompt_images.items: + image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png") + position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition + if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP: + _check_seconds(item.value, "Prompt image timestamp") + position = RunwayAleph2TimestampPosition(timestampSeconds=item.value) + else: + position = RunwayAleph2RelativePosition(positionPercentage=item.value) + prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url)) + + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"), + response_model=RunwayAleph2Response, + data=RunwayAleph2Request( + promptText=prompt, + videoUri=video_url, + seed=seed, + contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold), + keyframes=keyframe_models or None, + promptImage=prompt_image_models or None, + ), + ) + + final_response = await get_response(cls, initial_response.id) + if not final_response.output: + raise ValueError("Runway task succeeded but no video data found in response.") + + return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response))) + + class RunwayExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension): RunwayImageToVideoNodeGen3a, RunwayImageToVideoNodeGen4, RunwayTextToImageNode, + RunwayAleph2VideoToVideoNode, + RunwayAleph2KeyframeNode, + RunwayAleph2PromptImageNode, ] diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py index 9ce896ed0..24a9a0b06 100644 --- a/comfy_api_nodes/nodes_sonilo.py +++ b/comfy_api_nodes/nodes_sonilo.py @@ -16,7 +16,7 @@ from comfy_api_nodes.util import ( ) from comfy_api_nodes.util._helpers import ( default_base_url, - get_auth_header, + get_comfy_api_headers, get_node_id, is_processing_interrupted, ) @@ -174,8 +174,7 @@ async def _stream_sonilo_music( """POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes.""" url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/")) - headers: dict[str, str] = {} - headers.update(get_auth_header(cls)) + headers = get_comfy_api_headers(cls) headers.update(endpoint.headers) node_id = get_node_id(cls) diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 648defe3d..83cf7b001 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -9,6 +9,7 @@ from io import BytesIO from yarl import URL from comfy.cli_args import args +from comfy.deploy_environment import get_deploy_environment from comfy.model_management import processing_interrupted from comfy_api.latest import IO @@ -35,6 +36,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: return {} +def get_usage_source(node_cls: type[IO.ComfyNode]) -> str: + """Source of the prompt that triggered this API node. + + Defaults to "comfyui-api" when the submitting client didn't identify itself, + i.e. a direct API call to this server. + """ + return node_cls.hidden.comfy_usage_source or "comfyui-api" + + +def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + """Common headers (auth, deploy environment, usage source) for Comfy API requests. + + Centralizes the shared header set so every Comfy API request sends a consistent + set and new shared headers only need to be added in one place. Intended for + relative/cloud URLs resolved against ``default_base_url()``; because the result + includes auth, callers must not attach it to arbitrary absolute/presigned URLs. + """ + return { + **get_auth_header(node_cls), + "Comfy-Env": get_deploy_environment(), + "Comfy-Usage-Source": get_usage_source(node_cls), + } + + def default_base_url() -> str: return getattr(args, "comfy_api_base", "https://api.comfy.org") diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 57c501724..adcde7bcb 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -19,12 +19,10 @@ from comfy import utils from comfy_api.latest import IO from server import PromptServer -from comfy.deploy_environment import get_deploy_environment - from . import request_logger from ._helpers import ( default_base_url, - get_auth_header, + get_comfy_api_headers, get_node_id, is_processing_interrupted, sleep_with_interrupt, @@ -645,8 +643,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? - payload_headers.update(get_auth_header(cfg.node_cls)) - payload_headers["Comfy-Env"] = get_deploy_environment() + payload_headers.update(get_comfy_api_headers(cfg.node_cls)) if cfg.endpoint.headers: payload_headers.update(cfg.endpoint.headers) diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index aa588d038..0ec3c6e66 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -17,7 +17,7 @@ from folder_paths import get_output_directory from . import request_logger from ._helpers import ( default_base_url, - get_auth_header, + get_comfy_api_headers, is_processing_interrupted, sleep_with_interrupt, to_aiohttp_url, @@ -64,7 +64,7 @@ async def download_url_to_bytesio( if cls is None: raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) - headers = get_auth_header(cls) + headers = get_comfy_api_headers(cls) while True: attempt += 1 diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index afc663b22..ef1757ae5 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -245,6 +245,11 @@ class KV_Attn_Input: cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"]) if cache_key in self.cache: kk, vv = self.cache[cache_key] + + # Fix batch size changing. + kk = comfy.utils.repeat_to_batch_size(kk, k.shape[0]) + vv = comfy.utils.repeat_to_batch_size(vv, v.shape[0]) + self.set_cache = False return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 6f6c416a6..050a897dd 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -134,6 +134,17 @@ class CreateVideo(io.ComfyNode): io.Image.Input("images", tooltip="The images to create a video from."), io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + io.Int.Input( + "bit_depth", + min=8, + max=10, + default=8, + step=2, + tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less" + " banding, but some players and downstream nodes may not support it.", + optional=True, + display_mode=io.NumberDisplay.number, + ), ], outputs=[ io.Video.Output(), @@ -141,9 +152,14 @@ class CreateVideo(io.ComfyNode): ) @classmethod - def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: + def execute( + cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8, + ) -> io.NodeOutput: return io.NodeOutput( - InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + InputImpl.VideoFromComponents( + Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)), + bit_depth=bit_depth, + ) ) class GetVideoComponents(io.ComfyNode): @@ -154,7 +170,7 @@ class GetVideoComponents(io.ComfyNode): search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="video", - description="Extracts all components from a video: frames, audio, and framerate.", + description="Extracts all components from a video: frames, audio, framerate, and bit depth.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), ], @@ -162,13 +178,14 @@ class GetVideoComponents(io.ComfyNode): io.Image.Output(display_name="images"), io.Audio.Output(display_name="audio"), io.Float.Output(display_name="fps"), + io.Int.Output(display_name="bit_depth"), ], ) @classmethod def execute(cls, video: Input.Video) -> io.NodeOutput: components = video.get_components() - return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) + return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth()) class LoadVideo(io.ComfyNode): diff --git a/execution.py b/execution.py index e6c6f39d6..9e16e451d 100644 --- a/execution.py +++ b/execution.py @@ -200,6 +200,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) if io.Hidden.api_key_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + if io.Hidden.comfy_usage_source.name in hidden: + hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None) else: if "hidden" in valid_inputs: h = valid_inputs["hidden"] @@ -216,6 +218,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] + if h[x] == "COMFY_USAGE_SOURCE": + input_data_all[x] = [extra_data.get("comfy_usage_source", None)] v3_data["hidden_inputs"] = hidden_inputs_v3 return input_data_all, missing_keys, v3_data diff --git a/server.py b/server.py index 2d5da23fc..6b0029adf 100644 --- a/server.py +++ b/server.py @@ -973,6 +973,11 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] + + if "comfy_usage_source" not in extra_data: + usage_source = request.headers.get("Comfy-Usage-Source") + if usage_source: + extra_data["comfy_usage_source"] = usage_source if valid[0]: outputs_to_execute = valid[2] sensitive = {} diff --git a/tests-unit/comfy_api_test/video_bit_depth_test.py b/tests-unit/comfy_api_test/video_bit_depth_test.py new file mode 100644 index 000000000..6c7bc9163 --- /dev/null +++ b/tests-unit/comfy_api_test/video_bit_depth_test.py @@ -0,0 +1,93 @@ +import pytest +import torch +import av +import numpy as np +from fractions import Fraction +from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents +from comfy_api.latest._util.video_types import VideoComponents + + +@pytest.fixture(scope="module") +def gradient_components(): + """Narrow horizontal ramp (0.25..0.30) that needs more than 8 bits to stay smooth""" + width, height, frames = 64, 64, 3 + ramp = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1).expand(frames, height, width, 3) + return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30)) + + +@pytest.fixture(scope="module") +def src8(gradient_components, tmp_path_factory): + """8-bit h264 mp4 (Create Video default)""" + path = str(tmp_path_factory.mktemp("video") / "src8.mp4") + VideoFromComponents(gradient_components).save_to(path) + return path + + +@pytest.fixture(scope="module") +def src10(gradient_components, tmp_path_factory): + """10-bit h264 mp4 (Create Video with bit_depth=10)""" + path = str(tmp_path_factory.mktemp("video") / "src10.mp4") + VideoFromComponents(gradient_components, bit_depth=10).save_to(path) + return path + + +def probe(path): + """(codec, pix_fmt, bit_depth) of the first video stream""" + with av.open(path) as container: + stream = container.streams.video[0] + return (stream.codec.name, stream.format.name, max(c.bits for c in stream.format.components)) + + +def decoded_levels(path): + """Unique tonal levels in the first decoded frame (banding measure)""" + with av.open(path) as container: + frame = next(container.decode(container.streams.video[0])) + return len(np.unique(frame.to_ndarray(format="gbrpf32le")[..., 0])) + + +def video_packet_bytes(path): + """Raw video packet payloads; identical to the source's only for a true remux""" + with av.open(path) as container: + return [bytes(p) for p in container.demux(container.streams.video[0]) if p.size] + + +def test_create_video_bit_depth(src8, src10): + """Create Video's bit_depth picks the encoded depth (default 8-bit); 10-bit reduces banding""" + assert probe(src8) == ("h264", "yuv420p", 8) + assert probe(src10) == ("h264", "yuv420p10le", 10) + assert decoded_levels(src10) > 2 * decoded_levels(src8) + + +def test_save_auto_keeps_source_depth(src8, src10, tmp_path): + """Save Video (no bit_depth = auto) stream-copies the source, preserving its depth byte-for-byte""" + for name, src in [("p8", src8), ("p10", src10)]: + path = str(tmp_path / f"{name}.mp4") + VideoFromFile(src).save_to(path) + assert probe(path) == probe(src) + assert video_packet_bytes(path) == video_packet_bytes(src) + + +def test_save_explicit_depth_reencodes(src8, src10, tmp_path): + """An explicit bit_depth different from the source forces a re-encode to that depth""" + down = str(tmp_path / "down8.mp4") + VideoFromFile(src10).save_to(down, bit_depth=8) + assert probe(down) == ("h264", "yuv420p", 8) + + up = str(tmp_path / "up10.mp4") + VideoFromFile(src8).save_to(up, bit_depth=10) + assert probe(up) == ("h264", "yuv420p10le", 10) + + +def test_trim_keeps_source_depth(src10, tmp_path): + """Video Slice re-encodes (trim) but preserves the source's 10-bit depth""" + path = str(tmp_path / "trim.mp4") + VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False).save_to(path) + assert probe(path) == ("h264", "yuv420p10le", 10) + + +def test_get_bit_depth(gradient_components, src8, src10): + """get_bit_depth reports a video's depth (backs the Get Video Components output)""" + assert VideoFromFile(src8).get_bit_depth() == 8 + assert VideoFromFile(src10).get_bit_depth() == 10 + assert VideoFromComponents(gradient_components, bit_depth=10).get_bit_depth() == 10 + assert VideoFromComponents(gradient_components).get_bit_depth() == 8