This commit is contained in:
Christian Byrne 2026-04-19 04:00:07 -07:00 committed by GitHub
commit 022c4d1d41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 346 additions and 73 deletions

View File

@ -2,6 +2,7 @@ import asyncio
import contextlib
import json
import logging
import os
import time
import uuid
from collections.abc import Callable, Iterable
@ -32,6 +33,30 @@ from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInte
M = TypeVar("M", bound=BaseModel)
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
    """Default retry settings, overridable via environment variables.

    NOTE: dataclass field defaults are evaluated once, when this class
    statement executes (i.e. at import time), so changing the environment
    afterwards does not alter the defaults of newly created instances.
    Explicit constructor arguments still override these values.
    """

    # Maximum number of retry attempts (COMFY_API_MAX_RETRIES).
    max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
    # Initial delay in seconds before the first retry (COMFY_API_RETRY_DELAY).
    retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
    # Multiplier applied to the delay after each failed attempt (COMFY_API_RETRY_BACKOFF).
    retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)


# Shared module-level defaults consumed by the request helpers' signatures.
RETRY_DEFAULTS = _RetryDefaults()
class ApiEndpoint:
def __init__(
self,
@ -78,11 +103,21 @@ class _PollUIState:
price: float | None = None
estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = None # start time of current active interval (None if queued)
active_since: float | None = (
None # start time of current active interval (None if queued)
)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
COMPLETED_STATUSES = [
"succeeded",
"succeed",
"success",
"completed",
"finished",
"done",
"complete",
]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
@ -98,9 +133,9 @@ async def sync_op(
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
final_label_on_success: str | None = "Completed",
@ -131,7 +166,9 @@ async def sync_op(
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -178,7 +215,9 @@ async def poll_op(
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -192,9 +231,9 @@ async def sync_op_raw(
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
as_binary: bool = False,
@ -269,9 +308,15 @@ async def poll_op_raw(
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
completed_states = _normalize_statuses(
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
)
failed_states = _normalize_statuses(
FAILED_STATUSES if failed_statuses is None else failed_statuses
)
queued_states = _normalize_statuses(
QUEUED_STATUSES if queued_statuses is None else queued_statuses
)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
@ -289,7 +334,9 @@ async def poll_op_raw(
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
(now - state.active_since)
if state.active_since is not None
else 0.0
)
_display_time_progress(
cls,
@ -361,11 +408,15 @@ async def poll_op_raw(
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
if (
state.active_since is not None
): # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
if (
state.active_since is None
): # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
@ -442,7 +493,9 @@ def _display_text(
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
display_lines.append(
f"Status: {status.capitalize() if isinstance(status, str) else status}"
)
if price is not None:
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0":
@ -450,7 +503,9 @@ def _display_text(
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
PromptServer.instance.send_progress_text(
"\n".join(display_lines), get_node_id(node_cls)
)
def _display_time_progress(
@ -464,7 +519,11 @@ def _display_time_progress(
processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
pe = (
processing_elapsed_seconds
if processing_elapsed_seconds is not None
else elapsed_seconds
)
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
@ -503,7 +562,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
def _merge_params(
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@ -566,8 +627,14 @@ def _snapshot_request_body_for_logging(
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
file_fields.append(
{"field": field_name, "filename": str(filename or "")}
)
return {
"_multipart": True,
"form_fields": form_fields,
"file_fields": file_fields,
}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
@ -581,7 +648,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
params = _merge_params(
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
@ -591,13 +660,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(time.monotonic() - start_ts),
cfg.estimated_total,
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
start_time = (
cfg.progress_origin_ts
if cfg.progress_origin_ts is not None
else time.monotonic()
)
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
@ -614,7 +690,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
payload_headers = (
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
@ -623,7 +701,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
request_body_log = _snapshot_request_body_for_logging(
cfg.content_type, method, cfg.data, cfg.files
)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
@ -637,16 +717,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
raise ValueError(
"multipart_parser must return aiohttp.FormData"
)
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
form.add_field(
k,
str(v) if not isinstance(v, (bytes, bytearray)) else v,
)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
file_iter = (
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
)
for field_name, file_obj in file_iter:
if file_obj is None:
continue
@ -660,9 +747,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
form.add_field(
field_name,
file_value,
filename=filename,
content_type=content_type,
)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
elif (
cfg.content_type == "application/x-www-form-urlencoded"
and method != "GET"
):
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
@ -685,7 +780,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort
@ -705,7 +802,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
cfg.is_rate_limited is not None
and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
@ -713,7 +811,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
elif (
resp.status in _RETRY_STATUS
and (attempt - rate_limit_attempts) <= cfg.max_retries
):
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
@ -743,7 +844,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
continue
msg = _friendly_http_message(resp.status, body)
@ -770,7 +873,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(now - start_time),
cfg.estimated_total,
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
@ -800,9 +906,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
response_content_to_log = (
payload if isinstance(payload, dict) else text
)
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
extracted_price = (
cfg.price_extractor(payload)
if cfg.price_extractor
else None
)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -844,7 +956,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
delay *= cfg.retry_backoff
continue
@ -885,7 +999,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
if (
operation_succeeded
and cfg.monitor_progress
and cfg.final_label_on_success
):
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,

View File

@ -22,7 +22,7 @@ from ._helpers import (
sleep_with_interrupt,
to_aiohttp_url,
)
from .client import _diagnose_connectivity
from .client import RETRY_DEFAULTS, _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor
@ -34,9 +34,9 @@ async def download_url_to_bytesio(
dest: BytesIO | IO[bytes] | str | Path | None,
*,
timeout: float | None = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = max(5, RETRY_DEFAULTS.max_retries),
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
cls: type[COMFY_IO.ComfyNode] = None,
) -> None:
"""Stream-download a URL to `dest`.
@ -53,7 +53,9 @@ async def download_url_to_bytesio(
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
"""
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
raise ValueError(
"dest must be a path (str|Path) or a binary-writable object providing .write()."
)
attempt = 0
delay = retry_delay
@ -62,7 +64,9 @@ async def download_url_to_bytesio(
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
raise ValueError(
"For relative 'cloud' paths, the `cls` parameter is required."
)
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
@ -80,7 +84,9 @@ async def download_url_to_bytesio(
try:
with contextlib.suppress(Exception):
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
request_logger.log_request_response(
operation_id=op_id, request_method="GET", request_url=url
)
session = aiohttp.ClientSession(timeout=timeout_cfg)
stop_evt = asyncio.Event()
@ -96,8 +102,12 @@ async def download_url_to_bytesio(
monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
req_task = asyncio.create_task(
session.get(to_aiohttp_url(url), headers=headers)
)
done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task in done and req_task in pending:
req_task.cancel()
@ -117,7 +127,11 @@ async def download_url_to_bytesio(
body = await resp.json()
except (ContentTypeError, ValueError):
text = await resp.text()
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
body = (
text
if len(text) <= 4096
else f"[text {len(text)} bytes]"
)
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
@ -146,7 +160,9 @@ async def download_url_to_bytesio(
written = 0
while True:
try:
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
chunk = await asyncio.wait_for(
resp.content.read(1024 * 1024), timeout=1.0
)
except asyncio.TimeoutError:
chunk = b""
except asyncio.CancelledError:
@ -195,7 +211,9 @@ async def download_url_to_bytesio(
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The remote service appears unreachable at this time.") from e
raise ApiServerError(
"The remote service appears unreachable at this time."
) from e
finally:
if stop_evt is not None:
stop_evt.set()
@ -237,7 +255,9 @@ async def download_url_to_video_output(
) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
await download_url_to_bytesio(
video_url, result, timeout=timeout, max_retries=max_retries, cls=cls
)
return InputImpl.VideoFromFile(result)
@ -256,7 +276,11 @@ async def download_url_as_bytesio(
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download")
.strip("/")
.replace("/", "_")
)
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"

View File

@ -15,6 +15,7 @@ from comfy_api.latest import IO, Input, Types
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
RETRY_DEFAULTS,
ApiEndpoint,
_diagnose_connectivity,
_display_time_progress,
@ -77,13 +78,17 @@ async def upload_images_to_comfyapi(
for idx in range(num_to_upload):
tensor = tensors[idx]
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
img_io = tensor_to_bytesio(
tensor, total_pixels=total_pixels, mime_type=mime_type
)
effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1:
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
url = await upload_file_to_comfyapi(
cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts
)
download_urls.append(url)
return download_urls
@ -125,8 +130,12 @@ async def upload_audio_to_comfyapi(
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(
cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type
)
async def upload_video_to_comfyapi(
@ -161,7 +170,9 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
return await upload_file_to_comfyapi(
cls, video_bytes_io, filename, upload_mime_type, wait_label
)
_3D_MIME_TYPES = {
@ -197,7 +208,9 @@ async def upload_file_to_comfyapi(
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
request_object = UploadRequest(
file_name=filename, content_type=upload_mime_type
)
create_resp = await sync_op(
cls,
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
@ -223,9 +236,9 @@ async def upload_file(
file: BytesIO | str,
*,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str | None = None,
progress_origin_ts: float | None = None,
) -> None:
@ -250,11 +263,15 @@ async def upload_file(
if content_type:
headers["Content-Type"] = content_type
else:
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
skip_auto_headers.add(
"Content-Type"
) # Don't let aiohttp add Content-Type, it can break the signed request
attempt = 0
delay = retry_delay
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
start_ts = (
progress_origin_ts if progress_origin_ts is not None else time.monotonic()
)
op_uuid = uuid.uuid4().hex[:8]
while True:
attempt += 1
@ -268,7 +285,9 @@ async def upload_file(
if is_processing_interrupted():
return
if wait_label:
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
_display_time_progress(
cls, wait_label, int(time.monotonic() - start_ts), None
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
@ -286,10 +305,17 @@ async def upload_file(
)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
req = sess.put(
upload_url,
data=data,
headers=headers,
skip_auto_headers=skip_auto_headers,
)
req_task = asyncio.create_task(req)
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task in done and req_task in pending:
req_task.cancel()
@ -317,14 +343,19 @@ async def upload_file(
response_content=body,
error_message=msg,
)
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
if (
resp.status in {408, 429, 500, 502, 503, 504}
and attempt <= max_retries
):
await sleep_with_interrupt(
delay,
cls,
wait_label,
start_ts,
None,
display_callback=_display_time_progress if wait_label else None,
display_callback=_display_time_progress
if wait_label
else None,
)
delay *= retry_backoff
continue
@ -366,7 +397,9 @@ async def upload_file(
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The API service appears unreachable at this time.") from e
raise ApiServerError(
"The API service appears unreachable at this time."
) from e
finally:
stop_evt.set()
if monitor_task:
@ -381,7 +414,11 @@ async def upload_file(
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload")
.strip("/")
.replace("/", "_")
)
except Exception:
slug = "upload"
return f"{method}_{slug}_{op_uuid}_try{attempt}"

View File

@ -0,0 +1,94 @@
"""Tests for configurable retry defaults via environment variables.
Verifies that COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and
COMFY_API_RETRY_BACKOFF environment variables are respected.
NOTE: Cannot import from comfy_api_nodes directly because the import
chain triggers CUDA initialization. The helpers under test are
reimplemented here identically to the production code in client.py.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from unittest.mock import patch
import pytest
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
    """Test-local mirror of the production _RetryDefaults in client.py.

    Reimplemented here (per the module docstring) because importing the
    production module triggers CUDA initialization.  Field defaults are
    evaluated once, when this class statement runs, so environment
    overrides only take effect at import/definition time.
    """

    # Maximum retry attempts (COMFY_API_MAX_RETRIES).
    max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
    # Initial retry delay in seconds (COMFY_API_RETRY_DELAY).
    retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
    # Backoff multiplier applied per retry (COMFY_API_RETRY_BACKOFF).
    retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)
class TestEnvHelpers:
    """Unit tests for the _env_int / _env_float parsing helpers."""

    def test_env_int_returns_default_when_unset(self):
        with patch.dict(os.environ, {}, clear=True):
            result = _env_int("NONEXISTENT_KEY", 42)
        assert result == 42

    def test_env_int_returns_env_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "10"}):
            result = _env_int("TEST_KEY", 42)
        assert result == 10

    def test_env_int_returns_default_on_invalid_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "not_a_number"}):
            result = _env_int("TEST_KEY", 42)
        assert result == 42

    def test_env_float_returns_default_when_unset(self):
        with patch.dict(os.environ, {}, clear=True):
            result = _env_float("NONEXISTENT_KEY", 1.5)
        assert result == 1.5

    def test_env_float_returns_env_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "2.5"}):
            result = _env_float("TEST_KEY", 1.5)
        assert result == 2.5

    def test_env_float_returns_default_on_invalid_value(self):
        with patch.dict(os.environ, {"TEST_KEY": "bad"}):
            result = _env_float("TEST_KEY", 1.5)
        assert result == 1.5
class TestRetryDefaults:
    """Tests for _RetryDefaults: baked-in values, env overrides, immutability."""

    def test_hardcoded_defaults_match_expected(self):
        d = _RetryDefaults()
        assert (d.max_retries, d.retry_delay, d.retry_backoff) == (3, 1.0, 2.0)

    def test_env_vars_would_override_at_import_time(self):
        """Dataclass field defaults are evaluated at class-definition time.
        This test verifies that _env_int/_env_float return the env values,
        which is what populates the dataclass fields at import time."""
        cases = [
            (_env_int, "COMFY_API_MAX_RETRIES", "10", 3, 10),
            (_env_float, "COMFY_API_RETRY_DELAY", "3.0", 1.0, 3.0),
            (_env_float, "COMFY_API_RETRY_BACKOFF", "1.5", 2.0, 1.5),
        ]
        for helper, key, raw, fallback, expected in cases:
            with patch.dict(os.environ, {key: raw}):
                assert helper(key, fallback) == expected

    def test_explicit_construction_overrides_defaults(self):
        d = _RetryDefaults(max_retries=10, retry_delay=3.0, retry_backoff=1.5)
        assert (d.max_retries, d.retry_delay, d.retry_backoff) == (10, 3.0, 1.5)

    def test_frozen_dataclass(self):
        d = _RetryDefaults()
        # FrozenInstanceError subclasses AttributeError, so this catches it.
        with pytest.raises(AttributeError):
            d.max_retries = 999