mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-15 04:19:43 +08:00
117 lines
4.2 KiB
Python
117 lines
4.2 KiB
Python
import asyncio
|
|
import contextlib
|
|
import os
|
|
import re
|
|
import time
|
|
from collections.abc import Callable
|
|
from io import BytesIO
|
|
|
|
from yarl import URL
|
|
|
|
from comfy.cli_args import args
|
|
from comfy.deploy_environment import get_deploy_environment
|
|
from comfy.model_management import processing_interrupted
|
|
from comfy_api.latest import IO
|
|
|
|
from .common_exceptions import ProcessingInterrupted
|
|
|
|
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
|
|
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
|
|
|
|
|
|
def is_processing_interrupted() -> bool:
|
|
"""Return True if user/runtime requested interruption."""
|
|
return processing_interrupted()
|
|
|
|
|
|
def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
|
|
return node_cls.hidden.unique_id
|
|
|
|
|
|
def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
|
if node_cls.hidden.auth_token_comfy_org:
|
|
return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
|
|
if node_cls.hidden.api_key_comfy_org:
|
|
return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
|
|
return {}
|
|
|
|
|
|
def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
|
|
"""Source of the prompt that triggered this API node.
|
|
|
|
Defaults to "comfyui-api" when the submitting client didn't identify itself,
|
|
i.e. a direct API call to this server.
|
|
"""
|
|
return node_cls.hidden.comfy_usage_source or "comfyui-api"
|
|
|
|
|
|
def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
|
"""Common headers (auth, deploy environment, usage source) for Comfy API requests.
|
|
|
|
Centralizes the shared header set so every Comfy API request sends a consistent
|
|
set and new shared headers only need to be added in one place. Intended for
|
|
relative/cloud URLs resolved against ``default_base_url()``; because the result
|
|
includes auth, callers must not attach it to arbitrary absolute/presigned URLs.
|
|
"""
|
|
return {
|
|
**get_auth_header(node_cls),
|
|
"Comfy-Env": get_deploy_environment(),
|
|
"Comfy-Usage-Source": get_usage_source(node_cls),
|
|
}
|
|
|
|
|
|
def default_base_url() -> str:
|
|
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
|
|
|
|
|
async def sleep_with_interrupt(
|
|
seconds: float,
|
|
node_cls: type[IO.ComfyNode] | None,
|
|
label: str | None = None,
|
|
start_ts: float | None = None,
|
|
estimated_total: int | None = None,
|
|
*,
|
|
display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
|
|
):
|
|
"""
|
|
Sleep in 1s slices while:
|
|
- Checking for interruption (raises ProcessingInterrupted).
|
|
- Optionally emitting time progress via display_callback (if provided).
|
|
"""
|
|
end = time.monotonic() + seconds
|
|
while True:
|
|
if is_processing_interrupted():
|
|
raise ProcessingInterrupted("Task cancelled")
|
|
now = time.monotonic()
|
|
if start_ts is not None and label and display_callback:
|
|
with contextlib.suppress(Exception):
|
|
display_callback(node_cls, label, int(now - start_ts), estimated_total)
|
|
if now >= end:
|
|
break
|
|
await asyncio.sleep(min(1.0, end - now))
|
|
|
|
|
|
def mimetype_to_extension(mime_type: str) -> str:
|
|
"""Converts a MIME type to a file extension."""
|
|
return mime_type.split("/")[-1].lower()
|
|
|
|
|
|
def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
|
if isinstance(path_or_object, str):
|
|
return os.path.getsize(path_or_object)
|
|
return len(path_or_object.getvalue())
|
|
|
|
|
|
def to_aiohttp_url(url: str) -> URL:
|
|
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
|
|
escape and no malformed '%' sequences) and contains no raw whitespace/control
|
|
characters preserve the original encoding byte-for-byte (important for signed/presigned URLs).
|
|
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
|
|
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
|
|
# Avoid encoded=True if URL contains raw whitespace/control chars
|
|
return URL(url)
|
|
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
|
|
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
|
|
return URL(url, encoded=True)
|
|
return URL(url)
|