From b9abd21cb7ef255a3a1fe719e1688b186ad05060 Mon Sep 17 00:00:00 2001 From: Glary-Bot Date: Sun, 19 Apr 2026 10:48:48 +0000 Subject: [PATCH] feat: make API node retry parameters configurable via environment variables Adds COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and COMFY_API_RETRY_BACKOFF environment variables that override the default retry parameters for all API node HTTP requests (sync_op, sync_op_raw, upload_file, download_url_to_bytesio). Users in regions with unstable networks (e.g. behind the GFW in China) can increase the retry budget to tolerate longer network interruptions: COMFY_API_MAX_RETRIES=10 COMFY_API_RETRY_DELAY=2.0 python main.py Defaults remain unchanged (3 retries, 1.0s delay, 2.0x backoff) when the env vars are not set. --- comfy_api_nodes/util/client.py | 200 ++++++++++++++++++----- comfy_api_nodes/util/download_helpers.py | 52 ++++-- comfy_api_nodes/util/upload_helpers.py | 73 +++++++-- tests/test_retry_defaults.py | 94 +++++++++++ 4 files changed, 346 insertions(+), 73 deletions(-) create mode 100644 tests/test_retry_defaults.py diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9d730b81a..955f9c67d 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -2,6 +2,7 @@ import asyncio import contextlib import json import logging +import os import time import uuid from collections.abc import Callable, Iterable @@ -32,6 +33,30 @@ from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInte M = TypeVar("M", bound=BaseModel) +def _env_int(key: str, default: int) -> int: + try: + return int(os.environ[key]) + except (KeyError, ValueError): + return default + + +def _env_float(key: str, default: float) -> float: + try: + return float(os.environ[key]) + except (KeyError, ValueError): + return default + + +@dataclass(frozen=True) +class _RetryDefaults: + max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3) + retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0) + retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0) + + +RETRY_DEFAULTS = _RetryDefaults() + + class ApiEndpoint: def __init__( self, @@ -78,11 +103,21 @@ class _PollUIState: price: float | None = None estimated_duration: int | None = None base_processing_elapsed: float = 0.0 # sum of completed active intervals - active_since: float | None = None # start time of current active interval (None if queued) + active_since: float | None = ( + None # start time of current active interval (None if queued) + ) _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] +COMPLETED_STATUSES = [ + "succeeded", + "succeed", + "success", + "completed", + "finished", + "done", + "complete", +] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] @@ -98,9 +133,9 @@ async def sync_op( content_type: str = "application/json", timeout: float = 3600.0, multipart_parser: Callable | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff: float = 2.0, + max_retries: int = RETRY_DEFAULTS.max_retries, + retry_delay: float = RETRY_DEFAULTS.retry_delay, + retry_backoff: float = RETRY_DEFAULTS.retry_backoff, wait_label: str = "Waiting for server", estimated_duration: int | None = None, final_label_on_success: str | None = "Completed", @@ -131,7 +166,9 @@ async def sync_op( is_rate_limited=is_rate_limited, ) if not isinstance(raw, dict): - raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + raise Exception( + "Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)." + ) return _validate_or_raise(response_model, raw) @@ -178,7 +215,9 @@ async def poll_op( cancel_timeout=cancel_timeout, ) if not isinstance(raw, dict): - raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + raise Exception( + "Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)." + ) return _validate_or_raise(response_model, raw) @@ -192,9 +231,9 @@ async def sync_op_raw( content_type: str = "application/json", timeout: float = 3600.0, multipart_parser: Callable | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff: float = 2.0, + max_retries: int = RETRY_DEFAULTS.max_retries, + retry_delay: float = RETRY_DEFAULTS.retry_delay, + retry_backoff: float = RETRY_DEFAULTS.retry_backoff, wait_label: str = "Waiting for server", estimated_duration: int | None = None, as_binary: bool = False, @@ -269,9 +308,15 @@ async def poll_op_raw( Returns the final JSON response from the poll endpoint. """ - completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) - failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) - queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) + completed_states = _normalize_statuses( + COMPLETED_STATUSES if completed_statuses is None else completed_statuses + ) + failed_states = _normalize_statuses( + FAILED_STATUSES if failed_statuses is None else failed_statuses + ) + queued_states = _normalize_statuses( + QUEUED_STATUSES if queued_statuses is None else queued_statuses + ) started = time.monotonic() consumed_attempts = 0 # counts only non-queued polls @@ -289,7 +334,9 @@ async def poll_op_raw( break now = time.monotonic() proc_elapsed = state.base_processing_elapsed + ( - (now - state.active_since) if state.active_since is not None else 0.0 + (now - state.active_since) + if state.active_since is not None + else 0.0 ) _display_time_progress( cls, @@ -361,11 +408,15 @@ async def poll_op_raw( is_queued = status in queued_states if is_queued: - if state.active_since is not None: # If we just moved from active -> queued, close the active interval + if ( + state.active_since is not None + ): # If we just moved from active -> queued, close the active interval state.base_processing_elapsed += now_ts - state.active_since state.active_since = None else: - if state.active_since is None: # If we just moved from queued -> active, open a new active interval + if ( + state.active_since is None + ): # If we just moved from queued -> active, open a new active interval state.active_since = now_ts state.is_queued = is_queued @@ -442,7 +493,9 @@ def _display_text( ) -> None: display_lines: list[str] = [] if status: - display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") + display_lines.append( + f"Status: {status.capitalize() if isinstance(status, str) else status}" + ) if price is not None: p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": @@ -450,7 +503,9 @@ def _display_text( if text is not None: display_lines.append(text) if display_lines: - PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) + PromptServer.instance.send_progress_text( + "\n".join(display_lines), get_node_id(node_cls) + ) def _display_time_progress( @@ -464,7 +519,11 @@ def _display_time_progress( processing_elapsed_seconds: int | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: - pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + pe = ( + processing_elapsed_seconds + if processing_elapsed_seconds is not None + else elapsed_seconds + ) remaining = max(0, int(estimated_total) - int(pe)) time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" else: @@ -503,7 +562,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: raise ValueError("files tuple must be (filename, file[, content_type])") -def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: +def _merge_params( + endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None +) -> dict[str, Any]: params = dict(endpoint_params or {}) if method.upper() == "GET" and data: for k, v in data.items(): @@ -566,8 +627,14 @@ def _snapshot_request_body_for_logging( filename = file_obj[0] else: filename = getattr(file_obj, "name", field_name) - file_fields.append({"field": field_name, "filename": str(filename or "")}) - return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + file_fields.append( + {"field": field_name, "filename": str(filename or "")} + ) + return { + "_multipart": True, + "form_fields": form_fields, + "file_fields": file_fields, + } if content_type == "application/x-www-form-urlencoded": return data or {} return data or {} @@ -581,7 +648,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) method = cfg.endpoint.method - params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + params = _merge_params( + cfg.endpoint.query_params, method, cfg.data if method == "GET" else None + ) async def _monitor(stop_evt: asyncio.Event, start_ts: float): """Every second: update elapsed time and signal interruption.""" @@ -591,13 +660,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): return if cfg.monitor_progress: _display_time_progress( - cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + cfg.node_cls, + cfg.wait_label, + int(time.monotonic() - start_ts), + cfg.estimated_total, ) await asyncio.sleep(1.0) except asyncio.CancelledError: return # normal shutdown - start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + start_time = ( + cfg.progress_origin_ts + if cfg.progress_origin_ts is not None + else time.monotonic() + ) attempt = 0 delay = cfg.retry_delay rate_limit_attempts = 0 @@ -614,7 +690,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) - payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + payload_headers = ( + {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + ) if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? payload_headers.update(get_auth_header(cfg.node_cls)) if cfg.endpoint.headers: @@ -623,7 +701,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_kw: dict[str, Any] = {"headers": payload_headers} if method == "GET": payload_headers.pop("Content-Type", None) - request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + request_body_log = _snapshot_request_body_for_logging( + cfg.content_type, method, cfg.data, cfg.files + ) try: if cfg.monitor_progress: monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) @@ -637,16 +717,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if cfg.multipart_parser and cfg.data: form = cfg.multipart_parser(cfg.data) if not isinstance(form, aiohttp.FormData): - raise ValueError("multipart_parser must return aiohttp.FormData") + raise ValueError( + "multipart_parser must return aiohttp.FormData" + ) else: form = aiohttp.FormData(default_to_multipart=True) if cfg.data: for k, v in cfg.data.items(): if v is None: continue - form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + form.add_field( + k, + str(v) if not isinstance(v, (bytes, bytearray)) else v, + ) if cfg.files: - file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + file_iter = ( + cfg.files if isinstance(cfg.files, list) else cfg.files.items() + ) for field_name, file_obj in file_iter: if file_obj is None: continue @@ -660,9 +747,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if isinstance(file_value, BytesIO): with contextlib.suppress(Exception): file_value.seek(0) - form.add_field(field_name, file_value, filename=filename, content_type=content_type) + form.add_field( + field_name, + file_value, + filename=filename, + content_type=content_type, + ) payload_kw["data"] = form - elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + elif ( + cfg.content_type == "application/x-www-form-urlencoded" + and method != "GET" + ): payload_headers["Content-Type"] = "application/x-www-form-urlencoded" payload_kw["data"] = cfg.data or {} elif method != "GET": @@ -685,7 +780,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): tasks = {req_task} if monitor_task: tasks.add(monitor_task) - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) if monitor_task and monitor_task in done: # Interrupted – cancel the request and abort @@ -705,7 +802,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): wait_time = 0.0 retry_label = "" is_rl = resp.status == 429 or ( - cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body) + cfg.is_rate_limited is not None + and cfg.is_rate_limited(resp.status, body) ) if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit: rate_limit_attempts += 1 @@ -713,7 +811,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): rate_limit_delay *= cfg.retry_backoff retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}" should_retry = True - elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries: + elif ( + resp.status in _RETRY_STATUS + and (attempt - rate_limit_attempts) <= cfg.max_retries + ): wait_time = delay delay *= cfg.retry_backoff retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}" @@ -743,7 +844,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.wait_label if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None, cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, + display_callback=_display_time_progress + if cfg.monitor_progress + else None, ) continue msg = _friendly_http_message(resp.status, body) @@ -770,7 +873,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): raise ProcessingInterrupted("Task cancelled") if cfg.monitor_progress: _display_time_progress( - cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + cfg.node_cls, + cfg.wait_label, + int(now - start_time), + cfg.estimated_total, ) bytes_payload = bytes(buff) resp_headers = {k.lower(): v for k, v in resp.headers.items()} @@ -800,9 +906,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload = json.loads(text) if text else {} except json.JSONDecodeError: payload = {"_raw": text} - response_content_to_log = payload if isinstance(payload, dict) else text + response_content_to_log = ( + payload if isinstance(payload, dict) else text + ) with contextlib.suppress(Exception): - extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None + extracted_price = ( + cfg.price_extractor(payload) + if cfg.price_extractor + else None + ) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) request_logger.log_request_response( @@ -844,7 +956,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.wait_label if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None, cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, + display_callback=_display_time_progress + if cfg.monitor_progress + else None, ) delay *= cfg.retry_backoff continue @@ -885,7 +999,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if sess: with contextlib.suppress(Exception): await sess.close() - if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + if ( + operation_succeeded + and cfg.monitor_progress + and cfg.final_label_on_success + ): _display_time_progress( cfg.node_cls, status=cfg.final_label_on_success, diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index aa588d038..13b566945 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -22,7 +22,7 @@ from ._helpers import ( sleep_with_interrupt, to_aiohttp_url, ) -from .client import _diagnose_connectivity +from .client import RETRY_DEFAULTS, _diagnose_connectivity from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted from .conversions import bytesio_to_image_tensor @@ -34,9 +34,9 @@ async def download_url_to_bytesio( dest: BytesIO | IO[bytes] | str | Path | None, *, timeout: float | None = None, - max_retries: int = 5, - retry_delay: float = 1.0, - retry_backoff: float = 2.0, + max_retries: int = max(5, RETRY_DEFAULTS.max_retries), + retry_delay: float = RETRY_DEFAULTS.retry_delay, + retry_backoff: float = RETRY_DEFAULTS.retry_backoff, cls: type[COMFY_IO.ComfyNode] = None, ) -> None: """Stream-download a URL to `dest`. @@ -53,7 +53,9 @@ async def download_url_to_bytesio( ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) """ if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"): - raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().") + raise ValueError( + "dest must be a path (str|Path) or a binary-writable object providing .write()." + ) attempt = 0 delay = retry_delay @@ -62,7 +64,9 @@ async def download_url_to_bytesio( parsed_url = urlparse(url) if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? if cls is None: - raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") + raise ValueError( + "For relative 'cloud' paths, the `cls` parameter is required." + ) url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) headers = get_auth_header(cls) @@ -80,7 +84,9 @@ async def download_url_to_bytesio( try: with contextlib.suppress(Exception): - request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) + request_logger.log_request_response( + operation_id=op_id, request_method="GET", request_url=url + ) session = aiohttp.ClientSession(timeout=timeout_cfg) stop_evt = asyncio.Event() @@ -96,8 +102,12 @@ async def download_url_to_bytesio( monitor_task = asyncio.create_task(_monitor()) - req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers)) - done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + req_task = asyncio.create_task( + session.get(to_aiohttp_url(url), headers=headers) + ) + done, pending = await asyncio.wait( + {req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED + ) if monitor_task in done and req_task in pending: req_task.cancel() @@ -117,7 +127,11 @@ async def download_url_to_bytesio( body = await resp.json() except (ContentTypeError, ValueError): text = await resp.text() - body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" + body = ( + text + if len(text) <= 4096 + else f"[text {len(text)} bytes]" + ) request_logger.log_request_response( operation_id=op_id, request_method="GET", @@ -146,7 +160,9 @@ async def download_url_to_bytesio( written = 0 while True: try: - chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0) + chunk = await asyncio.wait_for( + resp.content.read(1024 * 1024), timeout=1.0 + ) except asyncio.TimeoutError: chunk = b"" except asyncio.CancelledError: @@ -195,7 +211,9 @@ async def download_url_to_bytesio( raise LocalNetworkError( "Unable to connect to the network. Please check your internet connection and try again." ) from e - raise ApiServerError("The remote service appears unreachable at this time.") from e + raise ApiServerError( + "The remote service appears unreachable at this time." + ) from e finally: if stop_evt is not None: stop_evt.set() @@ -237,7 +255,9 @@ async def download_url_to_video_output( ) -> InputImpl.VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() - await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) + await download_url_to_bytesio( + video_url, result, timeout=timeout, max_retries=max_retries, cls=cls + ) return InputImpl.VideoFromFile(result) @@ -256,7 +276,11 @@ async def download_url_as_bytesio( def _generate_operation_id(method: str, url: str, attempt: int) -> str: try: parsed = urlparse(url) - slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") + slug = ( + (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download") + .strip("/") + .replace("/", "_") + ) except Exception: slug = "download" return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 6d1d107a1..775c8d0d8 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -15,6 +15,7 @@ from comfy_api.latest import IO, Input, Types from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt from .client import ( + RETRY_DEFAULTS, ApiEndpoint, _diagnose_connectivity, _display_time_progress, @@ -77,13 +78,17 @@ async def upload_images_to_comfyapi( for idx in range(num_to_upload): tensor = tensors[idx] - img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type) + img_io = tensor_to_bytesio( + tensor, total_pixels=total_pixels, mime_type=mime_type + ) effective_label = wait_label if wait_label and show_batch_index and num_to_upload > 1: effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})" - url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts) + url = await upload_file_to_comfyapi( + cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts + ) download_urls.append(url) return download_urls @@ -125,8 +130,12 @@ async def upload_audio_to_comfyapi( sample_rate: int = audio["sample_rate"] waveform: torch.Tensor = audio["waveform"] audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) - return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type) + audio_bytes_io = audio_ndarray_to_bytesio( + audio_data_np, sample_rate, container_format, codec_name + ) + return await upload_file_to_comfyapi( + cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type + ) async def upload_video_to_comfyapi( @@ -161,7 +170,9 @@ async def upload_video_to_comfyapi( video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) - return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) + return await upload_file_to_comfyapi( + cls, video_bytes_io, filename, upload_mime_type, wait_label + ) _3D_MIME_TYPES = { @@ -197,7 +208,9 @@ async def upload_file_to_comfyapi( if upload_mime_type is None: request_object = UploadRequest(file_name=filename) else: - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + request_object = UploadRequest( + file_name=filename, content_type=upload_mime_type + ) create_resp = await sync_op( cls, endpoint=ApiEndpoint(path="/customers/storage", method="POST"), @@ -223,9 +236,9 @@ async def upload_file( file: BytesIO | str, *, content_type: str | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff: float = 2.0, + max_retries: int = RETRY_DEFAULTS.max_retries, + retry_delay: float = RETRY_DEFAULTS.retry_delay, + retry_backoff: float = RETRY_DEFAULTS.retry_backoff, wait_label: str | None = None, progress_origin_ts: float | None = None, ) -> None: @@ -250,11 +263,15 @@ async def upload_file( if content_type: headers["Content-Type"] = content_type else: - skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request + skip_auto_headers.add( + "Content-Type" + ) # Don't let aiohttp add Content-Type, it can break the signed request attempt = 0 delay = retry_delay - start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() + start_ts = ( + progress_origin_ts if progress_origin_ts is not None else time.monotonic() + ) op_uuid = uuid.uuid4().hex[:8] while True: attempt += 1 @@ -268,7 +285,9 @@ async def upload_file( if is_processing_interrupted(): return if wait_label: - _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) + _display_time_progress( + cls, wait_label, int(time.monotonic() - start_ts), None + ) await asyncio.sleep(1.0) except asyncio.CancelledError: return @@ -286,10 +305,17 @@ async def upload_file( ) sess = aiohttp.ClientSession(timeout=timeout) - req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) + req = sess.put( + upload_url, + data=data, + headers=headers, + skip_auto_headers=skip_auto_headers, + ) req_task = asyncio.create_task(req) - done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + {req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED + ) if monitor_task in done and req_task in pending: req_task.cancel() @@ -317,14 +343,19 @@ async def upload_file( response_content=body, error_message=msg, ) - if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: + if ( + resp.status in {408, 429, 500, 502, 503, 504} + and attempt <= max_retries + ): await sleep_with_interrupt( delay, cls, wait_label, start_ts, None, - display_callback=_display_time_progress if wait_label else None, + display_callback=_display_time_progress + if wait_label + else None, ) delay *= retry_backoff continue @@ -366,7 +397,9 @@ async def upload_file( raise LocalNetworkError( "Unable to connect to the network. Please check your internet connection and try again." ) from e - raise ApiServerError("The API service appears unreachable at this time.") from e + raise ApiServerError( + "The API service appears unreachable at this time." + ) from e finally: stop_evt.set() if monitor_task: @@ -381,7 +414,11 @@ async def upload_file( def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: try: parsed = urlparse(url) - slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") + slug = ( + (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload") + .strip("/") + .replace("/", "_") + ) except Exception: slug = "upload" return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/tests/test_retry_defaults.py b/tests/test_retry_defaults.py new file mode 100644 index 000000000..34327b790 --- /dev/null +++ b/tests/test_retry_defaults.py @@ -0,0 +1,94 @@ +"""Tests for configurable retry defaults via environment variables. + +Verifies that COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and +COMFY_API_RETRY_BACKOFF environment variables are respected. + +NOTE: Cannot import from comfy_api_nodes directly because the import +chain triggers CUDA initialization. The helpers under test are +reimplemented here identically to the production code in client.py. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from unittest.mock import patch + +import pytest + + +def _env_int(key: str, default: int) -> int: + try: + return int(os.environ[key]) + except (KeyError, ValueError): + return default + + +def _env_float(key: str, default: float) -> float: + try: + return float(os.environ[key]) + except (KeyError, ValueError): + return default + + +@dataclass(frozen=True) +class _RetryDefaults: + max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3) + retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0) + retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0) + + +class TestEnvHelpers: + def test_env_int_returns_default_when_unset(self): + with patch.dict(os.environ, {}, clear=True): + assert _env_int("NONEXISTENT_KEY", 42) == 42 + + def test_env_int_returns_env_value(self): + with patch.dict(os.environ, {"TEST_KEY": "10"}): + assert _env_int("TEST_KEY", 42) == 10 + + def test_env_int_returns_default_on_invalid_value(self): + with patch.dict(os.environ, {"TEST_KEY": "not_a_number"}): + assert _env_int("TEST_KEY", 42) == 42 + + def test_env_float_returns_default_when_unset(self): + with patch.dict(os.environ, {}, clear=True): + assert _env_float("NONEXISTENT_KEY", 1.5) == 1.5 + + def test_env_float_returns_env_value(self): + with patch.dict(os.environ, {"TEST_KEY": "2.5"}): + assert _env_float("TEST_KEY", 1.5) == 2.5 + + def test_env_float_returns_default_on_invalid_value(self): + with patch.dict(os.environ, {"TEST_KEY": "bad"}): + assert _env_float("TEST_KEY", 1.5) == 1.5 + + +class TestRetryDefaults: + def test_hardcoded_defaults_match_expected(self): + defaults = _RetryDefaults() + assert defaults.max_retries == 3 + assert defaults.retry_delay == 1.0 + assert defaults.retry_backoff == 2.0 + + def test_env_vars_would_override_at_import_time(self): + """Dataclass field defaults are evaluated at class-definition time. + This test verifies that _env_int/_env_float return the env values, + which is what populates the dataclass fields at import time.""" + with patch.dict(os.environ, {"COMFY_API_MAX_RETRIES": "10"}): + assert _env_int("COMFY_API_MAX_RETRIES", 3) == 10 + with patch.dict(os.environ, {"COMFY_API_RETRY_DELAY": "3.0"}): + assert _env_float("COMFY_API_RETRY_DELAY", 1.0) == 3.0 + with patch.dict(os.environ, {"COMFY_API_RETRY_BACKOFF": "1.5"}): + assert _env_float("COMFY_API_RETRY_BACKOFF", 2.0) == 1.5 + + def test_explicit_construction_overrides_defaults(self): + defaults = _RetryDefaults(max_retries=10, retry_delay=3.0, retry_backoff=1.5) + assert defaults.max_retries == 10 + assert defaults.retry_delay == 3.0 + assert defaults.retry_backoff == 1.5 + + def test_frozen_dataclass(self): + defaults = _RetryDefaults() + with pytest.raises(AttributeError): + defaults.max_retries = 999