mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
feat: make API node retry parameters configurable via environment variables
Adds the COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and COMFY_API_RETRY_BACKOFF environment variables, which override the default retry parameters for all API node HTTP requests (sync_op, sync_op_raw, upload_file, download_url_to_bytesio). Users in regions with unstable networks (e.g. behind the GFW in China) can increase the retry budget to tolerate longer network interruptions, for example: `COMFY_API_MAX_RETRIES=10 COMFY_API_RETRY_DELAY=2.0 python main.py`. The variables are read once, when the client module is imported. Defaults remain unchanged when the env vars are not set (3 retries, 1.0s delay, 2.0x backoff; download_url_to_bytesio keeps its minimum of 5 retries).
This commit is contained in:
parent
138571da95
commit
b9abd21cb7
@ -2,6 +2,7 @@ import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
@ -32,6 +33,30 @@ from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInte
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
def _env_int(key: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned when the variable is unset or cannot be
            parsed as an integer.

    Returns:
        The parsed integer, or ``default`` when the variable is missing or
        malformed.
    """
    raw = os.environ.get(key)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # A typo must not break startup, but it should not pass unnoticed:
        # otherwise the user believes the override is active when it is not.
        logging.warning(
            "Ignoring invalid integer value %r for %s; using default %s",
            raw,
            key,
            default,
        )
        return default
|
||||
|
||||
|
||||
def _env_float(key: str, default: float) -> float:
    """Read a finite floating-point setting from the environment.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned when the variable is unset, cannot be parsed
            as a float, or parses to NaN / infinity.

    Returns:
        The parsed finite float, or ``default`` otherwise.
    """
    import math  # local import: only this helper needs it

    raw = os.environ.get(key)
    if raw is None:
        return default
    try:
        value = float(raw)
    except ValueError:
        # Report the typo instead of silently dropping the override.
        logging.warning(
            "Ignoring invalid float value %r for %s; using default %s",
            raw,
            key,
            default,
        )
        return default
    if not math.isfinite(value):
        # float() happily parses "nan", "inf" and "-inf"; none of them is a
        # usable delay or backoff factor, so treat them like a malformed value.
        logging.warning(
            "Ignoring non-finite float value %r for %s; using default %s",
            raw,
            key,
            default,
        )
        return default
    return value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _RetryDefaults:
    """Immutable bundle of the default retry settings for API requests.

    Every field default is computed once, while this class body executes,
    from the matching COMFY_API_* environment variable; the literal passed
    as ``default`` is used when the variable is unset or unparsable.
    """

    # Default for the `max_retries` parameter of the request helpers.
    max_retries: int = _env_int("COMFY_API_MAX_RETRIES", default=3)
    # Default for the `retry_delay` parameter (initial delay between retries).
    retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", default=1.0)
    # Default for the `retry_backoff` parameter (factor the delay is multiplied by).
    retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", default=2.0)


# Shared instance whose attributes serve as the default argument values of
# the request helpers.
RETRY_DEFAULTS = _RetryDefaults()
|
||||
|
||||
|
||||
class ApiEndpoint:
|
||||
def __init__(
|
||||
self,
|
||||
@ -78,11 +103,21 @@ class _PollUIState:
|
||||
price: float | None = None
|
||||
estimated_duration: int | None = None
|
||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
active_since: float | None = (
|
||||
None # start time of current active interval (None if queued)
|
||||
)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
COMPLETED_STATUSES = [
|
||||
"succeeded",
|
||||
"succeed",
|
||||
"success",
|
||||
"completed",
|
||||
"finished",
|
||||
"done",
|
||||
"complete",
|
||||
]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||
|
||||
@ -98,9 +133,9 @@ async def sync_op(
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
max_retries: int = RETRY_DEFAULTS.max_retries,
|
||||
retry_delay: float = RETRY_DEFAULTS.retry_delay,
|
||||
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: int | None = None,
|
||||
final_label_on_success: str | None = "Completed",
|
||||
@ -131,7 +166,9 @@ async def sync_op(
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
raise Exception(
|
||||
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
|
||||
)
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
@ -178,7 +215,9 @@ async def poll_op(
|
||||
cancel_timeout=cancel_timeout,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
raise Exception(
|
||||
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
|
||||
)
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
@ -192,9 +231,9 @@ async def sync_op_raw(
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
max_retries: int = RETRY_DEFAULTS.max_retries,
|
||||
retry_delay: float = RETRY_DEFAULTS.retry_delay,
|
||||
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: int | None = None,
|
||||
as_binary: bool = False,
|
||||
@ -269,9 +308,15 @@ async def poll_op_raw(
|
||||
|
||||
Returns the final JSON response from the poll endpoint.
|
||||
"""
|
||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
|
||||
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
|
||||
completed_states = _normalize_statuses(
|
||||
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
|
||||
)
|
||||
failed_states = _normalize_statuses(
|
||||
FAILED_STATUSES if failed_statuses is None else failed_statuses
|
||||
)
|
||||
queued_states = _normalize_statuses(
|
||||
QUEUED_STATUSES if queued_statuses is None else queued_statuses
|
||||
)
|
||||
started = time.monotonic()
|
||||
consumed_attempts = 0 # counts only non-queued polls
|
||||
|
||||
@ -289,7 +334,9 @@ async def poll_op_raw(
|
||||
break
|
||||
now = time.monotonic()
|
||||
proc_elapsed = state.base_processing_elapsed + (
|
||||
(now - state.active_since) if state.active_since is not None else 0.0
|
||||
(now - state.active_since)
|
||||
if state.active_since is not None
|
||||
else 0.0
|
||||
)
|
||||
_display_time_progress(
|
||||
cls,
|
||||
@ -361,11 +408,15 @@ async def poll_op_raw(
|
||||
is_queued = status in queued_states
|
||||
|
||||
if is_queued:
|
||||
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
|
||||
if (
|
||||
state.active_since is not None
|
||||
): # If we just moved from active -> queued, close the active interval
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
else:
|
||||
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
|
||||
if (
|
||||
state.active_since is None
|
||||
): # If we just moved from queued -> active, open a new active interval
|
||||
state.active_since = now_ts
|
||||
|
||||
state.is_queued = is_queued
|
||||
@ -442,7 +493,9 @@ def _display_text(
|
||||
) -> None:
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
display_lines.append(
|
||||
f"Status: {status.capitalize() if isinstance(status, str) else status}"
|
||||
)
|
||||
if price is not None:
|
||||
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
|
||||
if p != "0":
|
||||
@ -450,7 +503,9 @@ def _display_text(
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
|
||||
PromptServer.instance.send_progress_text(
|
||||
"\n".join(display_lines), get_node_id(node_cls)
|
||||
)
|
||||
|
||||
|
||||
def _display_time_progress(
|
||||
@ -464,7 +519,11 @@ def _display_time_progress(
|
||||
processing_elapsed_seconds: int | None = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
pe = (
|
||||
processing_elapsed_seconds
|
||||
if processing_elapsed_seconds is not None
|
||||
else elapsed_seconds
|
||||
)
|
||||
remaining = max(0, int(estimated_total) - int(pe))
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
@ -503,7 +562,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
def _merge_params(
|
||||
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
params = dict(endpoint_params or {})
|
||||
if method.upper() == "GET" and data:
|
||||
for k, v in data.items():
|
||||
@ -566,8 +627,14 @@ def _snapshot_request_body_for_logging(
|
||||
filename = file_obj[0]
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||||
file_fields.append(
|
||||
{"field": field_name, "filename": str(filename or "")}
|
||||
)
|
||||
return {
|
||||
"_multipart": True,
|
||||
"form_fields": form_fields,
|
||||
"file_fields": file_fields,
|
||||
}
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
return data or {}
|
||||
return data or {}
|
||||
@ -581,7 +648,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
|
||||
method = cfg.endpoint.method
|
||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||
params = _merge_params(
|
||||
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
|
||||
)
|
||||
|
||||
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||||
"""Every second: update elapsed time and signal interruption."""
|
||||
@ -591,13 +660,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
return
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
||||
cfg.node_cls,
|
||||
cfg.wait_label,
|
||||
int(time.monotonic() - start_ts),
|
||||
cfg.estimated_total,
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return # normal shutdown
|
||||
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
start_time = (
|
||||
cfg.progress_origin_ts
|
||||
if cfg.progress_origin_ts is not None
|
||||
else time.monotonic()
|
||||
)
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
rate_limit_attempts = 0
|
||||
@ -614,7 +690,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
payload_headers = (
|
||||
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
@ -623,7 +701,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||||
if method == "GET":
|
||||
payload_headers.pop("Content-Type", None)
|
||||
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
||||
request_body_log = _snapshot_request_body_for_logging(
|
||||
cfg.content_type, method, cfg.data, cfg.files
|
||||
)
|
||||
try:
|
||||
if cfg.monitor_progress:
|
||||
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||||
@ -637,16 +717,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
if cfg.multipart_parser and cfg.data:
|
||||
form = cfg.multipart_parser(cfg.data)
|
||||
if not isinstance(form, aiohttp.FormData):
|
||||
raise ValueError("multipart_parser must return aiohttp.FormData")
|
||||
raise ValueError(
|
||||
"multipart_parser must return aiohttp.FormData"
|
||||
)
|
||||
else:
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
if cfg.data:
|
||||
for k, v in cfg.data.items():
|
||||
if v is None:
|
||||
continue
|
||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||
form.add_field(
|
||||
k,
|
||||
str(v) if not isinstance(v, (bytes, bytearray)) else v,
|
||||
)
|
||||
if cfg.files:
|
||||
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||
file_iter = (
|
||||
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||
)
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
@ -660,9 +747,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
if isinstance(file_value, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
file_value.seek(0)
|
||||
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
||||
form.add_field(
|
||||
field_name,
|
||||
file_value,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
payload_kw["data"] = form
|
||||
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
||||
elif (
|
||||
cfg.content_type == "application/x-www-form-urlencoded"
|
||||
and method != "GET"
|
||||
):
|
||||
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
payload_kw["data"] = cfg.data or {}
|
||||
elif method != "GET":
|
||||
@ -685,7 +780,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
tasks = {req_task}
|
||||
if monitor_task:
|
||||
tasks.add(monitor_task)
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
if monitor_task and monitor_task in done:
|
||||
# Interrupted – cancel the request and abort
|
||||
@ -705,7 +802,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
wait_time = 0.0
|
||||
retry_label = ""
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
cfg.is_rate_limited is not None
|
||||
and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
@ -713,7 +811,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
elif (
|
||||
resp.status in _RETRY_STATUS
|
||||
and (attempt - rate_limit_attempts) <= cfg.max_retries
|
||||
):
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||
@ -743,7 +844,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
display_callback=_display_time_progress
|
||||
if cfg.monitor_progress
|
||||
else None,
|
||||
)
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
@ -770,7 +873,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||
cfg.node_cls,
|
||||
cfg.wait_label,
|
||||
int(now - start_time),
|
||||
cfg.estimated_total,
|
||||
)
|
||||
bytes_payload = bytes(buff)
|
||||
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
||||
@ -800,9 +906,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload = json.loads(text) if text else {}
|
||||
except json.JSONDecodeError:
|
||||
payload = {"_raw": text}
|
||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||
response_content_to_log = (
|
||||
payload if isinstance(payload, dict) else text
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
extracted_price = (
|
||||
cfg.price_extractor(payload)
|
||||
if cfg.price_extractor
|
||||
else None
|
||||
)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
@ -844,7 +956,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
display_callback=_display_time_progress
|
||||
if cfg.monitor_progress
|
||||
else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
@ -885,7 +999,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
if sess:
|
||||
with contextlib.suppress(Exception):
|
||||
await sess.close()
|
||||
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
||||
if (
|
||||
operation_succeeded
|
||||
and cfg.monitor_progress
|
||||
and cfg.final_label_on_success
|
||||
):
|
||||
_display_time_progress(
|
||||
cfg.node_cls,
|
||||
status=cfg.final_label_on_success,
|
||||
|
||||
@ -22,7 +22,7 @@ from ._helpers import (
|
||||
sleep_with_interrupt,
|
||||
to_aiohttp_url,
|
||||
)
|
||||
from .client import _diagnose_connectivity
|
||||
from .client import RETRY_DEFAULTS, _diagnose_connectivity
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
from .conversions import bytesio_to_image_tensor
|
||||
|
||||
@ -34,9 +34,9 @@ async def download_url_to_bytesio(
|
||||
dest: BytesIO | IO[bytes] | str | Path | None,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
max_retries: int = max(5, RETRY_DEFAULTS.max_retries),
|
||||
retry_delay: float = RETRY_DEFAULTS.retry_delay,
|
||||
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> None:
|
||||
"""Stream-download a URL to `dest`.
|
||||
@ -53,7 +53,9 @@ async def download_url_to_bytesio(
|
||||
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
|
||||
"""
|
||||
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
|
||||
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
|
||||
raise ValueError(
|
||||
"dest must be a path (str|Path) or a binary-writable object providing .write()."
|
||||
)
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
@ -62,7 +64,9 @@ async def download_url_to_bytesio(
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
if cls is None:
|
||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||
raise ValueError(
|
||||
"For relative 'cloud' paths, the `cls` parameter is required."
|
||||
)
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
headers = get_auth_header(cls)
|
||||
|
||||
@ -80,7 +84,9 @@ async def download_url_to_bytesio(
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id, request_method="GET", request_url=url
|
||||
)
|
||||
|
||||
session = aiohttp.ClientSession(timeout=timeout_cfg)
|
||||
stop_evt = asyncio.Event()
|
||||
@ -96,8 +102,12 @@ async def download_url_to_bytesio(
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
|
||||
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
req_task = asyncio.create_task(
|
||||
session.get(to_aiohttp_url(url), headers=headers)
|
||||
)
|
||||
done, pending = await asyncio.wait(
|
||||
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
req_task.cancel()
|
||||
@ -117,7 +127,11 @@ async def download_url_to_bytesio(
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, ValueError):
|
||||
text = await resp.text()
|
||||
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
|
||||
body = (
|
||||
text
|
||||
if len(text) <= 4096
|
||||
else f"[text {len(text)} bytes]"
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
@ -146,7 +160,9 @@ async def download_url_to_bytesio(
|
||||
written = 0
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
|
||||
chunk = await asyncio.wait_for(
|
||||
resp.content.read(1024 * 1024), timeout=1.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
chunk = b""
|
||||
except asyncio.CancelledError:
|
||||
@ -195,7 +211,9 @@ async def download_url_to_bytesio(
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the network. Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError("The remote service appears unreachable at this time.") from e
|
||||
raise ApiServerError(
|
||||
"The remote service appears unreachable at this time."
|
||||
) from e
|
||||
finally:
|
||||
if stop_evt is not None:
|
||||
stop_evt.set()
|
||||
@ -237,7 +255,9 @@ async def download_url_to_video_output(
|
||||
) -> InputImpl.VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
||||
await download_url_to_bytesio(
|
||||
video_url, result, timeout=timeout, max_retries=max_retries, cls=cls
|
||||
)
|
||||
return InputImpl.VideoFromFile(result)
|
||||
|
||||
|
||||
@ -256,7 +276,11 @@ async def download_url_as_bytesio(
|
||||
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
    """Build an identifier for one download attempt.

    The id is ``<method>_<slug>_try<attempt>_<8 random hex chars>``. The slug
    is the last path segment of ``url``, falling back to the host and then
    to the literal ``"download"``; any parsing failure also yields
    ``"download"``.
    """
    try:
        parts = urlparse(url)
        last_segment = parts.path.rsplit("/", 1)[-1]
        candidate = last_segment or parts.netloc or "download"
        slug = candidate.strip("/").replace("/", "_")
    except Exception:
        slug = "download"
    random_suffix = uuid.uuid4().hex[:8]
    return f"{method}_{slug}_try{attempt}_{random_suffix}"
|
||||
|
||||
@ -15,6 +15,7 @@ from comfy_api.latest import IO, Input, Types
|
||||
from . import request_logger
|
||||
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||
from .client import (
|
||||
RETRY_DEFAULTS,
|
||||
ApiEndpoint,
|
||||
_diagnose_connectivity,
|
||||
_display_time_progress,
|
||||
@ -77,13 +78,17 @@ async def upload_images_to_comfyapi(
|
||||
|
||||
for idx in range(num_to_upload):
|
||||
tensor = tensors[idx]
|
||||
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
|
||||
img_io = tensor_to_bytesio(
|
||||
tensor, total_pixels=total_pixels, mime_type=mime_type
|
||||
)
|
||||
|
||||
effective_label = wait_label
|
||||
if wait_label and show_batch_index and num_to_upload > 1:
|
||||
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
|
||||
|
||||
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
|
||||
url = await upload_file_to_comfyapi(
|
||||
cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts
|
||||
)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
|
||||
@ -125,8 +130,12 @@ async def upload_audio_to_comfyapi(
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
return await upload_file_to_comfyapi(
|
||||
cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type
|
||||
)
|
||||
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
@ -161,7 +170,9 @@ async def upload_video_to_comfyapi(
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
|
||||
return await upload_file_to_comfyapi(
|
||||
cls, video_bytes_io, filename, upload_mime_type, wait_label
|
||||
)
|
||||
|
||||
|
||||
_3D_MIME_TYPES = {
|
||||
@ -197,7 +208,9 @@ async def upload_file_to_comfyapi(
|
||||
if upload_mime_type is None:
|
||||
request_object = UploadRequest(file_name=filename)
|
||||
else:
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
request_object = UploadRequest(
|
||||
file_name=filename, content_type=upload_mime_type
|
||||
)
|
||||
create_resp = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
|
||||
@ -223,9 +236,9 @@ async def upload_file(
|
||||
file: BytesIO | str,
|
||||
*,
|
||||
content_type: str | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
max_retries: int = RETRY_DEFAULTS.max_retries,
|
||||
retry_delay: float = RETRY_DEFAULTS.retry_delay,
|
||||
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
|
||||
wait_label: str | None = None,
|
||||
progress_origin_ts: float | None = None,
|
||||
) -> None:
|
||||
@ -250,11 +263,15 @@ async def upload_file(
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
else:
|
||||
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
|
||||
skip_auto_headers.add(
|
||||
"Content-Type"
|
||||
) # Don't let aiohttp add Content-Type, it can break the signed request
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
|
||||
start_ts = (
|
||||
progress_origin_ts if progress_origin_ts is not None else time.monotonic()
|
||||
)
|
||||
op_uuid = uuid.uuid4().hex[:8]
|
||||
while True:
|
||||
attempt += 1
|
||||
@ -268,7 +285,9 @@ async def upload_file(
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
if wait_label:
|
||||
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
|
||||
_display_time_progress(
|
||||
cls, wait_label, int(time.monotonic() - start_ts), None
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
@ -286,10 +305,17 @@ async def upload_file(
|
||||
)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
req = sess.put(
|
||||
upload_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
skip_auto_headers=skip_auto_headers,
|
||||
)
|
||||
req_task = asyncio.create_task(req)
|
||||
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
done, pending = await asyncio.wait(
|
||||
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
req_task.cancel()
|
||||
@ -317,14 +343,19 @@ async def upload_file(
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
|
||||
if (
|
||||
resp.status in {408, 429, 500, 502, 503, 504}
|
||||
and attempt <= max_retries
|
||||
):
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
wait_label,
|
||||
start_ts,
|
||||
None,
|
||||
display_callback=_display_time_progress if wait_label else None,
|
||||
display_callback=_display_time_progress
|
||||
if wait_label
|
||||
else None,
|
||||
)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
@ -366,7 +397,9 @@ async def upload_file(
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the network. Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError("The API service appears unreachable at this time.") from e
|
||||
raise ApiServerError(
|
||||
"The API service appears unreachable at this time."
|
||||
) from e
|
||||
finally:
|
||||
stop_evt.set()
|
||||
if monitor_task:
|
||||
@ -381,7 +414,11 @@ async def upload_file(
|
||||
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
    """Build an identifier for one upload attempt.

    The id is ``<method>_<slug>_<op_uuid>_try<attempt>``. The slug is the
    last path segment of ``url``, falling back to the host and then to the
    literal ``"upload"``; any parsing failure also yields ``"upload"``.
    """
    try:
        parts = urlparse(url)
        last_segment = parts.path.rsplit("/", 1)[-1]
        candidate = last_segment or parts.netloc or "upload"
        slug = candidate.strip("/").replace("/", "_")
    except Exception:
        slug = "upload"
    return f"{method}_{slug}_{op_uuid}_try{attempt}"
|
||||
|
||||
94
tests/test_retry_defaults.py
Normal file
94
tests/test_retry_defaults.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Tests for configurable retry defaults via environment variables.
|
||||
|
||||
Verifies that COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and
|
||||
COMFY_API_RETRY_BACKOFF environment variables are respected.
|
||||
|
||||
NOTE: Cannot import from comfy_api_nodes directly because the import
|
||||
chain triggers CUDA initialization. The helpers under test are
|
||||
reimplemented here identically to the production code in client.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _env_int(key: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned when the variable is unset or cannot be
            parsed as an integer.

    Returns:
        The parsed integer, or ``default`` when the variable is missing or
        malformed.
    """
    import logging  # local import: this test module does not import logging

    raw = os.environ.get(key)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # Report the malformed value instead of silently ignoring it.
        logging.warning(
            "Ignoring invalid integer value %r for %s; using default %s",
            raw,
            key,
            default,
        )
        return default
|
||||
|
||||
|
||||
def _env_float(key: str, default: float) -> float:
    """Read a finite floating-point setting from the environment.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned when the variable is unset, cannot be parsed
            as a float, or parses to NaN / infinity.

    Returns:
        The parsed finite float, or ``default`` otherwise.
    """
    import logging  # local imports: this test module imports neither
    import math

    raw = os.environ.get(key)
    if raw is None:
        return default
    try:
        value = float(raw)
    except ValueError:
        # Report the malformed value instead of silently ignoring it.
        logging.warning(
            "Ignoring invalid float value %r for %s; using default %s",
            raw,
            key,
            default,
        )
        return default
    if not math.isfinite(value):
        # float() parses "nan" / "inf" / "-inf"; none is a usable setting.
        logging.warning(
            "Ignoring non-finite float value %r for %s; using default %s",
            raw,
            key,
            default,
        )
        return default
    return value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _RetryDefaults:
    """Immutable bundle of retry settings used by the tests below.

    Every field default is computed once, while this class body executes,
    from the matching COMFY_API_* environment variable; the literal passed
    as ``default`` is used when the variable is unset or unparsable.
    """

    # Default number of retries.
    max_retries: int = _env_int("COMFY_API_MAX_RETRIES", default=3)
    # Default initial delay between retries.
    retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", default=1.0)
    # Default factor the delay is multiplied by.
    retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", default=2.0)
|
||||
|
||||
|
||||
class TestEnvHelpers:
    """Check that `_env_int` / `_env_float` parse values and fall back to the default."""

    def test_env_int_returns_default_when_unset(self):
        # Environment is emptied, so the lookup must fall through.
        with patch.dict(os.environ, {}, clear=True):
            result = _env_int("NONEXISTENT_KEY", 42)
        assert result == 42

    def test_env_int_returns_env_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "10"}):
            result = _env_int("TEST_KEY", 42)
        assert result == 10

    def test_env_int_returns_default_on_invalid_value(self):
        # A non-numeric string must not raise; the default wins.
        with patch.dict(os.environ, {"TEST_KEY": "not_a_number"}):
            result = _env_int("TEST_KEY", 42)
        assert result == 42

    def test_env_float_returns_default_when_unset(self):
        with patch.dict(os.environ, {}, clear=True):
            result = _env_float("NONEXISTENT_KEY", 1.5)
        assert result == 1.5

    def test_env_float_returns_env_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "2.5"}):
            result = _env_float("TEST_KEY", 1.5)
        assert result == 2.5

    def test_env_float_returns_default_on_invalid_value(self):
        # A non-numeric string must not raise; the default wins.
        with patch.dict(os.environ, {"TEST_KEY": "bad"}):
            result = _env_float("TEST_KEY", 1.5)
        assert result == 1.5
|
||||
|
||||
|
||||
class TestRetryDefaults:
    """Check the `_RetryDefaults` dataclass: defaults, overrides and immutability."""

    # The field defaults were fixed when `_RetryDefaults` was defined (module
    # import). If the process already exports an override, the hardcoded
    # values cannot be observed, so the check is skipped rather than failing.
    @pytest.mark.skipif(
        any(
            name in os.environ
            for name in (
                "COMFY_API_MAX_RETRIES",
                "COMFY_API_RETRY_DELAY",
                "COMFY_API_RETRY_BACKOFF",
            )
        ),
        reason="COMFY_API_* retry overrides are set in the test environment",
    )
    def test_hardcoded_defaults_match_expected(self):
        defaults = _RetryDefaults()
        assert defaults.max_retries == 3
        assert defaults.retry_delay == 1.0
        assert defaults.retry_backoff == 2.0

    def test_env_vars_would_override_at_import_time(self):
        """Dataclass field defaults are evaluated at class-definition time.

        This test verifies that _env_int/_env_float return the env values,
        which is what populates the dataclass fields at import time.
        """
        with patch.dict(os.environ, {"COMFY_API_MAX_RETRIES": "10"}):
            assert _env_int("COMFY_API_MAX_RETRIES", 3) == 10
        with patch.dict(os.environ, {"COMFY_API_RETRY_DELAY": "3.0"}):
            assert _env_float("COMFY_API_RETRY_DELAY", 1.0) == 3.0
        with patch.dict(os.environ, {"COMFY_API_RETRY_BACKOFF": "1.5"}):
            assert _env_float("COMFY_API_RETRY_BACKOFF", 2.0) == 1.5

    def test_explicit_construction_overrides_defaults(self):
        defaults = _RetryDefaults(max_retries=10, retry_delay=3.0, retry_backoff=1.5)
        assert defaults.max_retries == 10
        assert defaults.retry_delay == 3.0
        assert defaults.retry_backoff == 1.5

    def test_frozen_dataclass(self):
        defaults = _RetryDefaults()
        # Assigning to a field of a frozen dataclass raises FrozenInstanceError,
        # which is a subclass of AttributeError.
        with pytest.raises(AttributeError):
            defaults.max_retries = 999
|
||||
Loading…
Reference in New Issue
Block a user