This commit is contained in:
Christian Byrne 2026-04-19 04:00:07 -07:00 committed by GitHub
commit 022c4d1d41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 346 additions and 73 deletions

View File

@ -2,6 +2,7 @@ import asyncio
import contextlib import contextlib
import json import json
import logging import logging
import os
import time import time
import uuid import uuid
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
@ -32,6 +33,30 @@ from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInte
M = TypeVar("M", bound=BaseModel) M = TypeVar("M", bound=BaseModel)
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
    """Immutable bundle of retry tuning knobs shared by the API client helpers.

    Each default is read from an environment variable via ``_env_int`` /
    ``_env_float`` *once, at import time*; changing the environment after the
    module is imported has no effect on ``RETRY_DEFAULTS``.
    """

    # Maximum number of retries for retryable HTTP statuses.
    max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
    # Initial delay (seconds) before the first retry.
    retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
    # Multiplier applied to the delay after each retry (exponential backoff).
    retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)


# Module-wide singleton used as the default for retry parameters in
# sync_op/poll_op/upload helpers; frozen so callers cannot mutate it.
RETRY_DEFAULTS = _RetryDefaults()
class ApiEndpoint: class ApiEndpoint:
def __init__( def __init__(
self, self,
@ -78,11 +103,21 @@ class _PollUIState:
price: float | None = None price: float | None = None
estimated_duration: int | None = None estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = None # start time of current active interval (None if queued) active_since: float | None = (
None # start time of current active interval (None if queued)
)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] COMPLETED_STATUSES = [
"succeeded",
"succeed",
"success",
"completed",
"finished",
"done",
"complete",
]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
@ -98,9 +133,9 @@ async def sync_op(
content_type: str = "application/json", content_type: str = "application/json",
timeout: float = 3600.0, timeout: float = 3600.0,
multipart_parser: Callable | None = None, multipart_parser: Callable | None = None,
max_retries: int = 3, max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = 1.0, retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = 2.0, retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str = "Waiting for server", wait_label: str = "Waiting for server",
estimated_duration: int | None = None, estimated_duration: int | None = None,
final_label_on_success: str | None = "Completed", final_label_on_success: str | None = "Completed",
@ -131,7 +166,9 @@ async def sync_op(
is_rate_limited=is_rate_limited, is_rate_limited=is_rate_limited,
) )
if not isinstance(raw, dict): if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw) return _validate_or_raise(response_model, raw)
@ -178,7 +215,9 @@ async def poll_op(
cancel_timeout=cancel_timeout, cancel_timeout=cancel_timeout,
) )
if not isinstance(raw, dict): if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw) return _validate_or_raise(response_model, raw)
@ -192,9 +231,9 @@ async def sync_op_raw(
content_type: str = "application/json", content_type: str = "application/json",
timeout: float = 3600.0, timeout: float = 3600.0,
multipart_parser: Callable | None = None, multipart_parser: Callable | None = None,
max_retries: int = 3, max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = 1.0, retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = 2.0, retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str = "Waiting for server", wait_label: str = "Waiting for server",
estimated_duration: int | None = None, estimated_duration: int | None = None,
as_binary: bool = False, as_binary: bool = False,
@ -269,9 +308,15 @@ async def poll_op_raw(
Returns the final JSON response from the poll endpoint. Returns the final JSON response from the poll endpoint.
""" """
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) completed_states = _normalize_statuses(
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) COMPLETED_STATUSES if completed_statuses is None else completed_statuses
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) )
failed_states = _normalize_statuses(
FAILED_STATUSES if failed_statuses is None else failed_statuses
)
queued_states = _normalize_statuses(
QUEUED_STATUSES if queued_statuses is None else queued_statuses
)
started = time.monotonic() started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls consumed_attempts = 0 # counts only non-queued polls
@ -289,7 +334,9 @@ async def poll_op_raw(
break break
now = time.monotonic() now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + ( proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0 (now - state.active_since)
if state.active_since is not None
else 0.0
) )
_display_time_progress( _display_time_progress(
cls, cls,
@ -361,11 +408,15 @@ async def poll_op_raw(
is_queued = status in queued_states is_queued = status in queued_states
if is_queued: if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval if (
state.active_since is not None
): # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None state.active_since = None
else: else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval if (
state.active_since is None
): # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts state.active_since = now_ts
state.is_queued = is_queued state.is_queued = is_queued
@ -442,7 +493,9 @@ def _display_text(
) -> None: ) -> None:
display_lines: list[str] = [] display_lines: list[str] = []
if status: if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") display_lines.append(
f"Status: {status.capitalize() if isinstance(status, str) else status}"
)
if price is not None: if price is not None:
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0": if p != "0":
@ -450,7 +503,9 @@ def _display_text(
if text is not None: if text is not None:
display_lines.append(text) display_lines.append(text)
if display_lines: if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) PromptServer.instance.send_progress_text(
"\n".join(display_lines), get_node_id(node_cls)
)
def _display_time_progress( def _display_time_progress(
@ -464,7 +519,11 @@ def _display_time_progress(
processing_elapsed_seconds: int | None = None, processing_elapsed_seconds: int | None = None,
) -> None: ) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False: if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds pe = (
processing_elapsed_seconds
if processing_elapsed_seconds is not None
else elapsed_seconds
)
remaining = max(0, int(estimated_total) - int(pe)) remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else: else:
@ -503,7 +562,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])") raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: def _merge_params(
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
) -> dict[str, Any]:
params = dict(endpoint_params or {}) params = dict(endpoint_params or {})
if method.upper() == "GET" and data: if method.upper() == "GET" and data:
for k, v in data.items(): for k, v in data.items():
@ -566,8 +627,14 @@ def _snapshot_request_body_for_logging(
filename = file_obj[0] filename = file_obj[0]
else: else:
filename = getattr(file_obj, "name", field_name) filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")}) file_fields.append(
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} {"field": field_name, "filename": str(filename or "")}
)
return {
"_multipart": True,
"form_fields": form_fields,
"file_fields": file_fields,
}
if content_type == "application/x-www-form-urlencoded": if content_type == "application/x-www-form-urlencoded":
return data or {} return data or {}
return data or {} return data or {}
@ -581,7 +648,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) params = _merge_params(
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
)
async def _monitor(stop_evt: asyncio.Event, start_ts: float): async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption.""" """Every second: update elapsed time and signal interruption."""
@ -591,13 +660,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
return return
if cfg.monitor_progress: if cfg.monitor_progress:
_display_time_progress( _display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total cfg.node_cls,
cfg.wait_label,
int(time.monotonic() - start_ts),
cfg.estimated_total,
) )
await asyncio.sleep(1.0) await asyncio.sleep(1.0)
except asyncio.CancelledError: except asyncio.CancelledError:
return # normal shutdown return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() start_time = (
cfg.progress_origin_ts
if cfg.progress_origin_ts is not None
else time.monotonic()
)
attempt = 0 attempt = 0
delay = cfg.retry_delay delay = cfg.retry_delay
rate_limit_attempts = 0 rate_limit_attempts = 0
@ -614,7 +690,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} payload_headers = (
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls)) payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers: if cfg.endpoint.headers:
@ -623,7 +701,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_kw: dict[str, Any] = {"headers": payload_headers} payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET": if method == "GET":
payload_headers.pop("Content-Type", None) payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) request_body_log = _snapshot_request_body_for_logging(
cfg.content_type, method, cfg.data, cfg.files
)
try: try:
if cfg.monitor_progress: if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
@ -637,16 +717,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if cfg.multipart_parser and cfg.data: if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data) form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData): if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData") raise ValueError(
"multipart_parser must return aiohttp.FormData"
)
else: else:
form = aiohttp.FormData(default_to_multipart=True) form = aiohttp.FormData(default_to_multipart=True)
if cfg.data: if cfg.data:
for k, v in cfg.data.items(): for k, v in cfg.data.items():
if v is None: if v is None:
continue continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) form.add_field(
k,
str(v) if not isinstance(v, (bytes, bytearray)) else v,
)
if cfg.files: if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() file_iter = (
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
)
for field_name, file_obj in file_iter: for field_name, file_obj in file_iter:
if file_obj is None: if file_obj is None:
continue continue
@ -660,9 +747,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if isinstance(file_value, BytesIO): if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
file_value.seek(0) file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type) form.add_field(
field_name,
file_value,
filename=filename,
content_type=content_type,
)
payload_kw["data"] = form payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": elif (
cfg.content_type == "application/x-www-form-urlencoded"
and method != "GET"
):
payload_headers["Content-Type"] = "application/x-www-form-urlencoded" payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {} payload_kw["data"] = cfg.data or {}
elif method != "GET": elif method != "GET":
@ -685,7 +780,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
tasks = {req_task} tasks = {req_task}
if monitor_task: if monitor_task:
tasks.add(monitor_task) tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task and monitor_task in done: if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort # Interrupted cancel the request and abort
@ -705,7 +802,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
wait_time = 0.0 wait_time = 0.0
retry_label = "" retry_label = ""
is_rl = resp.status == 429 or ( is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body) cfg.is_rate_limited is not None
and cfg.is_rate_limited(resp.status, body)
) )
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit: if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1 rate_limit_attempts += 1
@ -713,7 +811,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
rate_limit_delay *= cfg.retry_backoff rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}" retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries: elif (
resp.status in _RETRY_STATUS
and (attempt - rate_limit_attempts) <= cfg.max_retries
):
wait_time = delay wait_time = delay
delay *= cfg.retry_backoff delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}" retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
@ -743,7 +844,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None, cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None,
cfg.estimated_total, cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None, display_callback=_display_time_progress
if cfg.monitor_progress
else None,
) )
continue continue
msg = _friendly_http_message(resp.status, body) msg = _friendly_http_message(resp.status, body)
@ -770,7 +873,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
raise ProcessingInterrupted("Task cancelled") raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress: if cfg.monitor_progress:
_display_time_progress( _display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total cfg.node_cls,
cfg.wait_label,
int(now - start_time),
cfg.estimated_total,
) )
bytes_payload = bytes(buff) bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()} resp_headers = {k.lower(): v for k, v in resp.headers.items()}
@ -800,9 +906,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload = json.loads(text) if text else {} payload = json.loads(text) if text else {}
except json.JSONDecodeError: except json.JSONDecodeError:
payload = {"_raw": text} payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text response_content_to_log = (
payload if isinstance(payload, dict) else text
)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None extracted_price = (
cfg.price_extractor(payload)
if cfg.price_extractor
else None
)
operation_succeeded = True operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time) final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response( request_logger.log_request_response(
@ -844,7 +956,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None, cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None,
cfg.estimated_total, cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None, display_callback=_display_time_progress
if cfg.monitor_progress
else None,
) )
delay *= cfg.retry_backoff delay *= cfg.retry_backoff
continue continue
@ -885,7 +999,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if sess: if sess:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await sess.close() await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: if (
operation_succeeded
and cfg.monitor_progress
and cfg.final_label_on_success
):
_display_time_progress( _display_time_progress(
cfg.node_cls, cfg.node_cls,
status=cfg.final_label_on_success, status=cfg.final_label_on_success,

View File

@ -22,7 +22,7 @@ from ._helpers import (
sleep_with_interrupt, sleep_with_interrupt,
to_aiohttp_url, to_aiohttp_url,
) )
from .client import _diagnose_connectivity from .client import RETRY_DEFAULTS, _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor from .conversions import bytesio_to_image_tensor
@ -34,9 +34,9 @@ async def download_url_to_bytesio(
dest: BytesIO | IO[bytes] | str | Path | None, dest: BytesIO | IO[bytes] | str | Path | None,
*, *,
timeout: float | None = None, timeout: float | None = None,
max_retries: int = 5, max_retries: int = max(5, RETRY_DEFAULTS.max_retries),
retry_delay: float = 1.0, retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = 2.0, retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
cls: type[COMFY_IO.ComfyNode] = None, cls: type[COMFY_IO.ComfyNode] = None,
) -> None: ) -> None:
"""Stream-download a URL to `dest`. """Stream-download a URL to `dest`.
@ -53,7 +53,9 @@ async def download_url_to_bytesio(
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
""" """
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"): if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().") raise ValueError(
"dest must be a path (str|Path) or a binary-writable object providing .write()."
)
attempt = 0 attempt = 0
delay = retry_delay delay = retry_delay
@ -62,7 +64,9 @@ async def download_url_to_bytesio(
parsed_url = urlparse(url) parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
if cls is None: if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") raise ValueError(
"For relative 'cloud' paths, the `cls` parameter is required."
)
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls) headers = get_auth_header(cls)
@ -80,7 +84,9 @@ async def download_url_to_bytesio(
try: try:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) request_logger.log_request_response(
operation_id=op_id, request_method="GET", request_url=url
)
session = aiohttp.ClientSession(timeout=timeout_cfg) session = aiohttp.ClientSession(timeout=timeout_cfg)
stop_evt = asyncio.Event() stop_evt = asyncio.Event()
@ -96,8 +102,12 @@ async def download_url_to_bytesio(
monitor_task = asyncio.create_task(_monitor()) monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers)) req_task = asyncio.create_task(
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) session.get(to_aiohttp_url(url), headers=headers)
)
done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task in done and req_task in pending: if monitor_task in done and req_task in pending:
req_task.cancel() req_task.cancel()
@ -117,7 +127,11 @@ async def download_url_to_bytesio(
body = await resp.json() body = await resp.json()
except (ContentTypeError, ValueError): except (ContentTypeError, ValueError):
text = await resp.text() text = await resp.text()
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" body = (
text
if len(text) <= 4096
else f"[text {len(text)} bytes]"
)
request_logger.log_request_response( request_logger.log_request_response(
operation_id=op_id, operation_id=op_id,
request_method="GET", request_method="GET",
@ -146,7 +160,9 @@ async def download_url_to_bytesio(
written = 0 written = 0
while True: while True:
try: try:
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0) chunk = await asyncio.wait_for(
resp.content.read(1024 * 1024), timeout=1.0
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
chunk = b"" chunk = b""
except asyncio.CancelledError: except asyncio.CancelledError:
@ -195,7 +211,9 @@ async def download_url_to_bytesio(
raise LocalNetworkError( raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again." "Unable to connect to the network. Please check your internet connection and try again."
) from e ) from e
raise ApiServerError("The remote service appears unreachable at this time.") from e raise ApiServerError(
"The remote service appears unreachable at this time."
) from e
finally: finally:
if stop_evt is not None: if stop_evt is not None:
stop_evt.set() stop_evt.set()
@ -237,7 +255,9 @@ async def download_url_to_video_output(
) -> InputImpl.VideoFromFile: ) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.""" """Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO() result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) await download_url_to_bytesio(
video_url, result, timeout=timeout, max_retries=max_retries, cls=cls
)
return InputImpl.VideoFromFile(result) return InputImpl.VideoFromFile(result)
@ -256,7 +276,11 @@ async def download_url_as_bytesio(
def _generate_operation_id(method: str, url: str, attempt: int) -> str: def _generate_operation_id(method: str, url: str, attempt: int) -> str:
try: try:
parsed = urlparse(url) parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download")
.strip("/")
.replace("/", "_")
)
except Exception: except Exception:
slug = "download" slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"

View File

@ -15,6 +15,7 @@ from comfy_api.latest import IO, Input, Types
from . import request_logger from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import ( from .client import (
RETRY_DEFAULTS,
ApiEndpoint, ApiEndpoint,
_diagnose_connectivity, _diagnose_connectivity,
_display_time_progress, _display_time_progress,
@ -77,13 +78,17 @@ async def upload_images_to_comfyapi(
for idx in range(num_to_upload): for idx in range(num_to_upload):
tensor = tensors[idx] tensor = tensors[idx]
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type) img_io = tensor_to_bytesio(
tensor, total_pixels=total_pixels, mime_type=mime_type
)
effective_label = wait_label effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1: if wait_label and show_batch_index and num_to_upload > 1:
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})" effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts) url = await upload_file_to_comfyapi(
cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts
)
download_urls.append(url) download_urls.append(url)
return download_urls return download_urls
@ -125,8 +130,12 @@ async def upload_audio_to_comfyapi(
sample_rate: int = audio["sample_rate"] sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"] waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) audio_bytes_io = audio_ndarray_to_bytesio(
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type) audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(
cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type
)
async def upload_video_to_comfyapi( async def upload_video_to_comfyapi(
@ -161,7 +170,9 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec) video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0) video_bytes_io.seek(0)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) return await upload_file_to_comfyapi(
cls, video_bytes_io, filename, upload_mime_type, wait_label
)
_3D_MIME_TYPES = { _3D_MIME_TYPES = {
@ -197,7 +208,9 @@ async def upload_file_to_comfyapi(
if upload_mime_type is None: if upload_mime_type is None:
request_object = UploadRequest(file_name=filename) request_object = UploadRequest(file_name=filename)
else: else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) request_object = UploadRequest(
file_name=filename, content_type=upload_mime_type
)
create_resp = await sync_op( create_resp = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path="/customers/storage", method="POST"), endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
@ -223,9 +236,9 @@ async def upload_file(
file: BytesIO | str, file: BytesIO | str,
*, *,
content_type: str | None = None, content_type: str | None = None,
max_retries: int = 3, max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = 1.0, retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = 2.0, retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str | None = None, wait_label: str | None = None,
progress_origin_ts: float | None = None, progress_origin_ts: float | None = None,
) -> None: ) -> None:
@ -250,11 +263,15 @@ async def upload_file(
if content_type: if content_type:
headers["Content-Type"] = content_type headers["Content-Type"] = content_type
else: else:
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request skip_auto_headers.add(
"Content-Type"
) # Don't let aiohttp add Content-Type, it can break the signed request
attempt = 0 attempt = 0
delay = retry_delay delay = retry_delay
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() start_ts = (
progress_origin_ts if progress_origin_ts is not None else time.monotonic()
)
op_uuid = uuid.uuid4().hex[:8] op_uuid = uuid.uuid4().hex[:8]
while True: while True:
attempt += 1 attempt += 1
@ -268,7 +285,9 @@ async def upload_file(
if is_processing_interrupted(): if is_processing_interrupted():
return return
if wait_label: if wait_label:
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) _display_time_progress(
cls, wait_label, int(time.monotonic() - start_ts), None
)
await asyncio.sleep(1.0) await asyncio.sleep(1.0)
except asyncio.CancelledError: except asyncio.CancelledError:
return return
@ -286,10 +305,17 @@ async def upload_file(
) )
sess = aiohttp.ClientSession(timeout=timeout) sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) req = sess.put(
upload_url,
data=data,
headers=headers,
skip_auto_headers=skip_auto_headers,
)
req_task = asyncio.create_task(req) req_task = asyncio.create_task(req)
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task in done and req_task in pending: if monitor_task in done and req_task in pending:
req_task.cancel() req_task.cancel()
@ -317,14 +343,19 @@ async def upload_file(
response_content=body, response_content=body,
error_message=msg, error_message=msg,
) )
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: if (
resp.status in {408, 429, 500, 502, 503, 504}
and attempt <= max_retries
):
await sleep_with_interrupt( await sleep_with_interrupt(
delay, delay,
cls, cls,
wait_label, wait_label,
start_ts, start_ts,
None, None,
display_callback=_display_time_progress if wait_label else None, display_callback=_display_time_progress
if wait_label
else None,
) )
delay *= retry_backoff delay *= retry_backoff
continue continue
@ -366,7 +397,9 @@ async def upload_file(
raise LocalNetworkError( raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again." "Unable to connect to the network. Please check your internet connection and try again."
) from e ) from e
raise ApiServerError("The API service appears unreachable at this time.") from e raise ApiServerError(
"The API service appears unreachable at this time."
) from e
finally: finally:
stop_evt.set() stop_evt.set()
if monitor_task: if monitor_task:
@ -381,7 +414,11 @@ async def upload_file(
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
try: try:
parsed = urlparse(url) parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload")
.strip("/")
.replace("/", "_")
)
except Exception: except Exception:
slug = "upload" slug = "upload"
return f"{method}_{slug}_{op_uuid}_try{attempt}" return f"{method}_{slug}_{op_uuid}_try{attempt}"

View File

@ -0,0 +1,94 @@
"""Tests for configurable retry defaults via environment variables.
Verifies that COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and
COMFY_API_RETRY_BACKOFF environment variables are respected.
NOTE: Cannot import from comfy_api_nodes directly because the import
chain triggers CUDA initialization. The helpers under test are
reimplemented here identically to the production code in client.py.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from unittest.mock import patch
import pytest
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
    """Immutable retry configuration resolved from the environment.

    Field defaults are evaluated once, at class-definition (module import)
    time, so the COMFY_API_* variables must be set before this module is
    imported in order to take effect.
    """

    # Maximum number of retry attempts (COMFY_API_MAX_RETRIES).
    max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
    # Initial delay between attempts, in seconds (COMFY_API_RETRY_DELAY).
    retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
    # Multiplier applied to the delay after each failure (COMFY_API_RETRY_BACKOFF).
    retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)
class TestEnvHelpers:
    """Behavioral tests for the _env_int / _env_float parsing helpers."""

    def test_env_int_returns_default_when_unset(self):
        # With a fully cleared environment the caller-supplied default wins.
        with patch.dict(os.environ, {}, clear=True):
            result = _env_int("NONEXISTENT_KEY", 42)
        assert result == 42

    def test_env_int_returns_env_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "10"}):
            result = _env_int("TEST_KEY", 42)
        assert result == 10

    def test_env_int_returns_default_on_invalid_value(self):
        # Unparseable values fall back to the default rather than raising.
        with patch.dict(os.environ, {"TEST_KEY": "not_a_number"}):
            result = _env_int("TEST_KEY", 42)
        assert result == 42

    def test_env_float_returns_default_when_unset(self):
        with patch.dict(os.environ, {}, clear=True):
            result = _env_float("NONEXISTENT_KEY", 1.5)
        assert result == 1.5

    def test_env_float_returns_env_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "2.5"}):
            result = _env_float("TEST_KEY", 1.5)
        assert result == 2.5

    def test_env_float_returns_default_on_invalid_value(self):
        # Unparseable values fall back to the default rather than raising.
        with patch.dict(os.environ, {"TEST_KEY": "bad"}):
            result = _env_float("TEST_KEY", 1.5)
        assert result == 1.5
class TestRetryDefaults:
    """Tests for the _RetryDefaults dataclass and its import-time defaults."""

    def test_hardcoded_defaults_match_expected(self):
        # NOTE(review): assumes no COMFY_API_* variables were set when this
        # module was imported — the defaults are baked in at class definition.
        defaults = _RetryDefaults()
        assert defaults.max_retries == 3
        assert defaults.retry_delay == 1.0
        assert defaults.retry_backoff == 2.0

    def test_env_vars_would_override_at_import_time(self):
        """Dataclass field defaults are evaluated at class-definition time.

        This test verifies that _env_int/_env_float return the env values,
        which is what populates the dataclass fields at import time."""
        overrides = {
            "COMFY_API_MAX_RETRIES": "10",
            "COMFY_API_RETRY_DELAY": "3.0",
            "COMFY_API_RETRY_BACKOFF": "1.5",
        }
        with patch.dict(os.environ, overrides):
            assert _env_int("COMFY_API_MAX_RETRIES", 3) == 10
            assert _env_float("COMFY_API_RETRY_DELAY", 1.0) == 3.0
            assert _env_float("COMFY_API_RETRY_BACKOFF", 2.0) == 1.5

    def test_explicit_construction_overrides_defaults(self):
        configured = _RetryDefaults(max_retries=10, retry_delay=3.0, retry_backoff=1.5)
        assert configured.max_retries == 10
        assert configured.retry_delay == 3.0
        assert configured.retry_backoff == 1.5

    def test_frozen_dataclass(self):
        # FrozenInstanceError is a subclass of AttributeError, so this
        # assertion covers the frozen-dataclass mutation guard.
        defaults = _RetryDefaults()
        with pytest.raises(AttributeError):
            defaults.max_retries = 999