mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 05:52:33 +08:00
bugfix: fix typo in apply_directory for custom_nodes_directory
allow for PATH style ';' delimited custom_node directories.
change delimiter type for separate folders per platform.
feat(API-nodes): move Rodin3D nodes to new client; removed old api client.py (#10645)
Fix qwen controlnet regression. (#10657)
Enable pinned memory by default on Nvidia. (#10656)
Removed the --fast pinned_memory flag.
You can use --disable-pinned-memory to disable it. Please report if it
causes any issues.
Pinned mem also seems to work on AMD. (#10658)
Remove environment variable.
Removed environment variable fallback for custom nodes directory.
Update documentation for custom nodes directory
Clarified documentation on custom nodes directory argument, removed documentation on environment variable
Clarify release cycle. (#10667)
Tell users they need to upload their logs in bug reports. (#10671)
mm: guard against double pin and unpin explicitly (#10672)
As commented, if you let cuda be the one to detect double pin/unpinning
it actually creates an async GPU error.
Only unpin tensor if it was pinned by ComfyUI (#10677)
Make ScaleROPE node work on Flux. (#10686)
Add logging for model unloading. (#10692)
Unload weights if vram usage goes up between runs. (#10690)
ops: Put weight cast on the offload stream (#10697)
This needs to be on the offload stream. This reproduced a black screen
with low resolution images on a slow bus when using FP8.
Update CI workflow to remove dead macOS runner. (#10704)
* Update CI workflow to remove dead macOS runner.
* revert
* revert
Don't pin tensor if not a torch.nn.parameter.Parameter (#10718)
Update README.md for Intel Arc GPU installation, remove IPEX (#10729)
IPEX is no longer needed for Intel Arc GPUs. Removing instruction to setup ipex.
mm/mp: always unload re-used but modified models (#10724)
The partial unloader path in model re-use flow skips straight to the
actual unload without any check of the patching UUID. This means that
if you do an upscale flow with a model patch on an existing model, it
will not apply your patchings.
Fix by delaying the partial_unload until after the uuid checks. This
is done by making partial_unload a model of partial_load where extra_mem
is -ve.
qwen: reduce VRAM usage (#10725)
Clean up a bunch of stacked and no-longer-needed tensors on the QWEN
VRAM peak (currently FFN).
With this I go from OOMing at B=37x1328x1328 to being able to
successfully run B=47 (RTX5090).
Update Python 3.14 compatibility notes in README (#10730)
Quantized Ops fixes (#10715)
* offload support, bug fixes, remove mixins
* add readme
add PR template for API-Nodes (#10736)
feat: add create_time dict to prompt field in /history and /queue (#10741)
flux: reduce VRAM usage (#10737)
Clean up a bunch of stack tensors on Flux. This takes me from B=19 to B=22
for 1600x1600 on RTX5090.
Better instructions for the portable. (#10743)
Use same code for chroma and flux blocks so that optimizations are shared. (#10746)
Fix custom nodes import error. (#10747)
This should fix the import errors but will break if the custom nodes actually try to use the class.
revert import reordering
revert imports pt 2
Add left padding support to tokenizers. (#10753)
chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)
Revert "chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)" (#10759)
This reverts commit 9a02382568.
Change ROCm nightly install command to 7.1 (#10764)
937 lines
38 KiB
Python
937 lines
38 KiB
Python
import asyncio
|
||
import contextlib
|
||
import json
|
||
import logging
|
||
import time
|
||
import uuid
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from io import BytesIO
|
||
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||
from urllib.parse import urljoin, urlparse
|
||
|
||
import aiohttp
|
||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||
from pydantic import BaseModel
|
||
|
||
from comfy import utils
|
||
from comfy_api.latest import IO
|
||
from server import PromptServer
|
||
|
||
from . import request_logger
|
||
from ._helpers import (
|
||
default_base_url,
|
||
get_auth_header,
|
||
get_node_id,
|
||
is_processing_interrupted,
|
||
sleep_with_interrupt,
|
||
)
|
||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||
|
||
# Generic placeholder for the Pydantic model type a response is validated into;
# it ties the `response_model` argument of sync_op/poll_op to their return type.
M = TypeVar("M", bound=BaseModel)
|
||
|
||
|
||
class ApiEndpoint:
    """Plain description of one HTTP endpoint: path, verb, query parameters and extra headers."""

    def __init__(
        self,
        path: str,
        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
        *,
        query_params: Optional[dict[str, Any]] = None,
        headers: Optional[dict[str, str]] = None,
    ):
        """Record the endpoint description.

        Args:
            path: Endpoint path (or URL) the request is sent to.
            method: HTTP verb; defaults to "GET".
            query_params: Query-string parameters; a falsy value becomes an empty dict.
            headers: Extra request headers; a falsy value becomes an empty dict.
        """
        self.path = path
        self.method = method
        # A truthy mapping supplied by the caller is stored as-is (not copied).
        self.query_params = query_params if query_params else {}
        self.headers = headers if headers else {}
|
||
|
||
|
||
@dataclass
class _RequestConfig:
    """Bundle of everything `_request_base` needs to perform one (retryable) HTTP request."""

    # Node class the request is made on behalf of (used for auth headers and progress display).
    node_cls: type[IO.ComfyNode]
    # Target endpoint: path/URL, method, query parameters and extra headers.
    endpoint: ApiEndpoint
    # Total timeout in seconds handed to aiohttp.ClientTimeout.
    timeout: float
    # Body encoding: JSON by default, or multipart / urlencoded form.
    content_type: str
    # Request payload (query parameters for GET, body otherwise).
    data: Optional[dict[str, Any]]
    # Files for multipart uploads, as a mapping or a list of (field, file) pairs.
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
    # Optional callable that turns `data` into an aiohttp.FormData for multipart requests.
    multipart_parser: Optional[Callable]
    # Retry policy: number of retries, initial delay (seconds) and delay multiplier.
    max_retries: int
    retry_delay: float
    retry_backoff: float
    # Status label shown in the node UI while waiting.
    wait_label: str = "Waiting"
    # When False, no per-second progress/interruption monitor is started.
    monitor_progress: bool = True
    # Estimated duration in seconds forwarded to the time-progress display.
    estimated_total: Optional[int] = None
    # Label intended for the final UI update on success (consumer not visible in this chunk).
    final_label_on_success: Optional[str] = "Completed"
    # Monotonic timestamp to measure elapsed time from; defaults to "now" when None.
    progress_origin_ts: Optional[float] = None
|
||
|
||
|
||
@dataclass
|
||
class _PollUIState:
|
||
started: float
|
||
status_label: str = "Queued"
|
||
is_queued: bool = True
|
||
price: Optional[float] = None
|
||
estimated_duration: Optional[int] = None
|
||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||
active_since: Optional[float] = None # start time of current active interval (None if queued)
|
||
|
||
|
||
# HTTP status codes for which a failed request is retried.
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}

# Default status vocabularies used by polling when the caller does not supply its own.
COMPLETED_STATUSES = [
    "succeeded",
    "succeed",
    "success",
    "completed",
    "finished",
    "done",
]
FAILED_STATUSES = [
    "cancelled",
    "canceled",
    "fail",
    "failed",
    "error",
]
QUEUED_STATUSES = [
    "created",
    "queued",
    "queueing",
    "submitted",
]
|
||
|
||
|
||
async def sync_op(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    data: Optional[BaseModel] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> M:
    """
    Make a single request through `sync_op_raw` and validate the JSON result into `response_model`.

    All keyword arguments are forwarded unchanged to `sync_op_raw`; the request is always
    made with as_binary=False.

    Raises an Exception when the response is not a JSON object (dict).
    """
    payload = await sync_op_raw(
        cls,
        endpoint,
        data=data,
        files=files,
        content_type=content_type,
        timeout=timeout,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        estimated_duration=estimated_duration,
        as_binary=False,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
        monitor_progress=monitor_progress,
    )
    if isinstance(payload, dict):
        return _validate_or_raise(response_model, payload)
    raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||
|
||
|
||
async def poll_op(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    status_extractor: Callable[[M], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[BaseModel] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> M:
    """
    Poll an endpoint through `poll_op_raw` and validate the final JSON into `response_model`.

    The model-level extractors are adapted with `_wrap_model_extractor` before being handed
    to `poll_op_raw`; every other argument is forwarded unchanged.

    Raises an Exception when the final response is not a JSON object (dict).
    """
    final_payload = await poll_op_raw(
        cls,
        poll_endpoint=poll_endpoint,
        status_extractor=_wrap_model_extractor(response_model, status_extractor),
        progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
        price_extractor=_wrap_model_extractor(response_model, price_extractor),
        completed_statuses=completed_statuses,
        failed_statuses=failed_statuses,
        queued_statuses=queued_statuses,
        data=data,
        poll_interval=poll_interval,
        max_poll_attempts=max_poll_attempts,
        timeout_per_poll=timeout_per_poll,
        max_retries_per_poll=max_retries_per_poll,
        retry_delay_per_poll=retry_delay_per_poll,
        retry_backoff_per_poll=retry_backoff_per_poll,
        estimated_duration=estimated_duration,
        cancel_endpoint=cancel_endpoint,
        cancel_timeout=cancel_timeout,
    )
    if isinstance(final_payload, dict):
        return _validate_or_raise(response_model, final_payload)
    raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||
|
||
|
||
async def sync_op_raw(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    as_binary: bool = False,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
    """
    Perform one network request via `_request_base`.

    - With as_binary=False (default) the result is a JSON dict (or {'_raw': '<text>'} if non-JSON).
    - With as_binary=True the result is bytes.

    A Pydantic model passed as `data` is dumped to a dict without None values, and
    top-level Enum members are replaced by their `.value`. A plain dict is used as-is.
    """
    if isinstance(data, BaseModel):
        dumped = data.model_dump(exclude_none=True)
        # Only top-level enum members are flattened; nested values are left untouched.
        data = {key: (value.value if isinstance(value, Enum) else value) for key, value in dumped.items()}
    request_cfg = _RequestConfig(
        node_cls=cls,
        endpoint=endpoint,
        timeout=timeout,
        content_type=content_type,
        data=data,
        files=files,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        monitor_progress=monitor_progress,
        estimated_total=estimated_duration,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
    )
    return await _request_base(request_cfg, expect_binary=as_binary)
|
||
|
||
|
||
async def poll_op_raw(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> dict[str, Any]:
    """
    Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing
    and calls the Cancel endpoint (if provided) when processing is interrupted.

    Uses the default completed, failed and queued status lists unless the caller supplies its own.
    Only polls that observe a non-queued status count towards `max_poll_attempts`.

    Returns the final JSON response from the poll endpoint.

    Raises:
        ProcessingInterrupted: re-raised unchanged after the best-effort cancel request.
        LocalNetworkError, ApiServerError: re-raised unchanged.
        Exception: for a failed task, a non-JSON poll response, exhausted attempts, or any other
            error (wrapped as "Polling aborted due to error: ...").
    """
    completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
    failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
    queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
    started = time.monotonic()
    consumed_attempts = 0  # counts only non-queued polls

    progress_bar = utils.ProgressBar(100) if progress_extractor else None
    last_progress: Optional[int] = None

    state = _PollUIState(started=started, estimated_duration=estimated_duration)
    stop_ticker = asyncio.Event()

    async def _cancel_remote_task() -> None:
        """Best-effort call to the cancel endpoint (if any); failures are suppressed."""
        if not cancel_endpoint:
            return
        with contextlib.suppress(Exception):
            await sync_op_raw(
                cls,
                cancel_endpoint,
                timeout=cancel_timeout,
                max_retries=0,
                wait_label="Cancelling task",
                estimated_duration=None,
                as_binary=False,
                final_label_on_success=None,
                monitor_progress=False,
            )

    async def _ticker():
        """Emit a UI update every second while polling is in progress."""
        try:
            while not stop_ticker.is_set():
                if is_processing_interrupted():
                    break
                now = time.monotonic()
                proc_elapsed = state.base_processing_elapsed + (
                    (now - state.active_since) if state.active_since is not None else 0.0
                )
                _display_time_progress(
                    cls,
                    status=state.status_label,
                    elapsed_seconds=int(now - state.started),
                    estimated_total=state.estimated_duration,
                    price=state.price,
                    is_queued=state.is_queued,
                    processing_elapsed_seconds=int(proc_elapsed),
                )
                # Wait on the stop event rather than sleeping unconditionally, so the ticker
                # (and whoever awaits it) finishes as soon as polling stops.
                with contextlib.suppress(asyncio.TimeoutError):
                    await asyncio.wait_for(stop_ticker.wait(), timeout=1.0)
        except Exception as exc:
            logging.debug("Polling ticker exited: %s", exc)

    ticker_task = asyncio.create_task(_ticker())
    try:
        while consumed_attempts < max_poll_attempts:
            try:
                resp_json = await sync_op_raw(
                    cls,
                    poll_endpoint,
                    data=data,
                    timeout=timeout_per_poll,
                    max_retries=max_retries_per_poll,
                    retry_delay=retry_delay_per_poll,
                    retry_backoff=retry_backoff_per_poll,
                    wait_label="Checking",
                    estimated_duration=None,
                    as_binary=False,
                    final_label_on_success=None,
                    monitor_progress=False,
                )
                if not isinstance(resp_json, dict):
                    raise Exception("Polling endpoint returned non-JSON response.")
            except ProcessingInterrupted:
                await _cancel_remote_task()
                raise

            # A broken status extractor must not abort polling; treat the status as unknown.
            try:
                status = _normalize_status_value(status_extractor(resp_json))
            except Exception as e:
                logging.error("Status extraction failed: %s", e)
                status = None

            if price_extractor:
                new_price = price_extractor(resp_json)
                if new_price is not None:
                    state.price = new_price

            if progress_extractor:
                new_progress = progress_extractor(resp_json)
                if new_progress is not None and last_progress != new_progress:
                    progress_bar.update_absolute(new_progress, total=100)
                    last_progress = new_progress

            now_ts = time.monotonic()
            is_queued = status in queued_states

            if is_queued:
                if state.active_since is not None:  # If we just moved from active -> queued, close the active interval
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
            else:
                if state.active_since is None:  # If we just moved from queued -> active, open a new active interval
                    state.active_since = now_ts

            state.is_queued = is_queued
            state.status_label = status or ("Queued" if is_queued else "Processing")
            if status in completed_states:
                if state.active_since is not None:
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
                # Stop the ticker first so it cannot overwrite the final status line below.
                stop_ticker.set()
                with contextlib.suppress(Exception):
                    await ticker_task

                if progress_bar and last_progress != 100:
                    progress_bar.update_absolute(100, total=100)

                _display_time_progress(
                    cls,
                    status=status if status else "Completed",
                    elapsed_seconds=int(now_ts - started),
                    estimated_total=estimated_duration,
                    price=state.price,
                    is_queued=False,
                    processing_elapsed_seconds=int(state.base_processing_elapsed),
                )
                return resp_json

            if status in failed_states:
                msg = f"Task failed: {json.dumps(resp_json)}"
                logging.error(msg)
                raise Exception(msg)

            try:
                await sleep_with_interrupt(poll_interval, cls, None, None, None)
            except ProcessingInterrupted:
                await _cancel_remote_task()
                raise
            if not is_queued:
                consumed_attempts += 1

        raise Exception(
            f"Polling timed out after {max_poll_attempts} non-queued attempts "
            f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
        )
    except (ProcessingInterrupted, LocalNetworkError, ApiServerError):
        # These carry their own meaning for callers; do not wrap them.
        raise
    except Exception as e:
        raise Exception(f"Polling aborted due to error: {e}") from e
    finally:
        stop_ticker.set()
        with contextlib.suppress(Exception):
            await ticker_task
|
||
|
||
|
||
def _display_text(
    node_cls: type[IO.ComfyNode],
    text: Optional[str],
    *,
    status: Optional[Union[str, int]] = None,
    price: Optional[float] = None,
) -> None:
    """
    Send a short multi-line progress text for the node to the prompt server.

    Lines are emitted in the order status, price, text; parts that are missing are skipped,
    and nothing is sent when there is nothing to show.
    """
    parts: list[str] = []
    if status:
        # String statuses are capitalized for display; other values are shown as-is.
        shown_status = status.capitalize() if isinstance(status, str) else status
        parts.append(f"Status: {shown_status}")
    if price is not None:
        parts.append(f"Price: ${float(price):,.4f}")
    if text is not None:
        parts.append(text)
    if not parts:
        return
    PromptServer.instance.send_progress_text("\n".join(parts), get_node_id(node_cls))
|
||
|
||
|
||
def _display_time_progress(
    node_cls: type[IO.ComfyNode],
    status: Optional[Union[str, int]],
    elapsed_seconds: int,
    estimated_total: Optional[int] = None,
    *,
    price: Optional[float] = None,
    is_queued: Optional[bool] = None,
    processing_elapsed_seconds: Optional[int] = None,
) -> None:
    """
    Show elapsed time (and, when possible, an estimate of the remaining time) for the node.

    The remaining-time hint is only shown when a positive `estimated_total` is known and
    `is_queued` is explicitly False. It is computed from `processing_elapsed_seconds` when
    given, otherwise from `elapsed_seconds`, and never drops below zero.
    """
    can_estimate = estimated_total is not None and estimated_total > 0 and is_queued is False
    if not can_estimate:
        time_line = f"Time elapsed: {int(elapsed_seconds)}s"
    else:
        basis = elapsed_seconds if processing_elapsed_seconds is None else processing_elapsed_seconds
        remaining = max(0, int(estimated_total) - int(basis))
        time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
    _display_text(node_cls, time_line, status=status, price=price)
|
||
|
||
|
||
async def _diagnose_connectivity() -> dict[str, bool]:
    """
    Best-effort connectivity probe to tell local network problems from server problems.

    Returns a dict with "internet_accessible" and "api_accessible". The API health URL is
    only probed when the generic internet check succeeded; ClientError/OSError during a
    probe simply leave the corresponding flag False.
    """
    report = {
        "internet_accessible": False,
        "api_accessible": False,
    }

    async def _responds(session: aiohttp.ClientSession, url: str) -> bool:
        """Return True when a GET to `url` answers with a status below 500."""
        reachable = False
        with contextlib.suppress(ClientError, OSError):
            async with session.get(url) as resp:
                reachable = resp.status < 500
        return reachable

    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=5.0)) as session:
        report["internet_accessible"] = await _responds(session, "https://www.google.com")
        if report["internet_accessible"]:
            base = urlparse(default_base_url())
            report["api_accessible"] = await _responds(session, f"{base.scheme}://{base.netloc}/health")
    return report
|
||
|
||
|
||
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||
"""Normalize (filename, value, content_type)."""
|
||
if len(t) == 2:
|
||
return t[0], t[1], "application/octet-stream"
|
||
if len(t) == 3:
|
||
return t[0], t[1], t[2]
|
||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||
|
||
|
||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||
params = dict(endpoint_params or {})
|
||
if method.upper() == "GET" and data:
|
||
for k, v in data.items():
|
||
if v is not None:
|
||
params[k] = v
|
||
return params
|
||
|
||
|
||
def _friendly_http_message(status: int, body: Any) -> str:
|
||
if status == 401:
|
||
return "Unauthorized: Please login first to use this node."
|
||
if status == 402:
|
||
return "Payment Required: Please add credits to your account to use this node."
|
||
if status == 409:
|
||
return "There is a problem with your account. Please contact support@comfy.org."
|
||
if status == 429:
|
||
return "Rate Limit Exceeded: Please try again later."
|
||
try:
|
||
if isinstance(body, dict):
|
||
err = body.get("error")
|
||
if isinstance(err, dict):
|
||
msg = err.get("message")
|
||
typ = err.get("type")
|
||
if msg and typ:
|
||
return f"API Error: {msg} (Type: {typ})"
|
||
if msg:
|
||
return f"API Error: {msg}"
|
||
return f"API Error: {json.dumps(body)}"
|
||
else:
|
||
txt = str(body)
|
||
if len(txt) <= 200:
|
||
return f"API Error (raw): {txt}"
|
||
return f"API Error (status {status})"
|
||
except Exception:
|
||
return f"HTTP {status}: Unknown error"
|
||
|
||
|
||
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||
slug = path.strip("/").replace("/", "_") or "op"
|
||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||
|
||
|
||
def _snapshot_request_body_for_logging(
|
||
content_type: str,
|
||
method: str,
|
||
data: Optional[dict[str, Any]],
|
||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||
) -> Optional[Union[dict[str, Any], str]]:
|
||
if method.upper() == "GET":
|
||
return None
|
||
if content_type == "multipart/form-data":
|
||
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
|
||
file_fields: list[dict[str, str]] = []
|
||
if files:
|
||
file_iter = files if isinstance(files, list) else list(files.items())
|
||
for field_name, file_obj in file_iter:
|
||
if file_obj is None:
|
||
continue
|
||
if isinstance(file_obj, tuple):
|
||
filename = file_obj[0]
|
||
else:
|
||
filename = getattr(file_obj, "name", field_name)
|
||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||
if content_type == "application/x-www-form-urlencoded":
|
||
return data or {}
|
||
return data or {}
|
||
|
||
|
||
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
||
url = cfg.endpoint.path
|
||
parsed_url = urlparse(url)
|
||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||
|
||
method = cfg.endpoint.method
|
||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||
|
||
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||
"""Every second: update elapsed time and signal interruption."""
|
||
try:
|
||
while not stop_evt.is_set():
|
||
if is_processing_interrupted():
|
||
return
|
||
if cfg.monitor_progress:
|
||
_display_time_progress(
|
||
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
||
)
|
||
await asyncio.sleep(1.0)
|
||
except asyncio.CancelledError:
|
||
return # normal shutdown
|
||
|
||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||
attempt = 0
|
||
delay = cfg.retry_delay
|
||
operation_succeeded: bool = False
|
||
final_elapsed_seconds: Optional[int] = None
|
||
while True:
|
||
attempt += 1
|
||
stop_event = asyncio.Event()
|
||
monitor_task: Optional[asyncio.Task] = None
|
||
sess: Optional[aiohttp.ClientSession] = None
|
||
|
||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||
|
||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||
if cfg.endpoint.headers:
|
||
payload_headers.update(cfg.endpoint.headers)
|
||
|
||
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||
if method == "GET":
|
||
payload_headers.pop("Content-Type", None)
|
||
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
||
try:
|
||
if cfg.monitor_progress:
|
||
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||
|
||
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
|
||
sess = aiohttp.ClientSession(timeout=timeout)
|
||
|
||
if cfg.content_type == "multipart/form-data" and method != "GET":
|
||
# aiohttp will set Content-Type boundary; remove any fixed Content-Type
|
||
payload_headers.pop("Content-Type", None)
|
||
if cfg.multipart_parser and cfg.data:
|
||
form = cfg.multipart_parser(cfg.data)
|
||
if not isinstance(form, aiohttp.FormData):
|
||
raise ValueError("multipart_parser must return aiohttp.FormData")
|
||
else:
|
||
form = aiohttp.FormData(default_to_multipart=True)
|
||
if cfg.data:
|
||
for k, v in cfg.data.items():
|
||
if v is None:
|
||
continue
|
||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||
if cfg.files:
|
||
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||
for field_name, file_obj in file_iter:
|
||
if file_obj is None:
|
||
continue
|
||
if isinstance(file_obj, tuple):
|
||
filename, file_value, content_type = _unpack_tuple(file_obj)
|
||
else:
|
||
filename = getattr(file_obj, "name", field_name)
|
||
file_value = file_obj
|
||
content_type = "application/octet-stream"
|
||
# Attempt to rewind BytesIO for retries
|
||
if isinstance(file_value, BytesIO):
|
||
with contextlib.suppress(Exception):
|
||
file_value.seek(0)
|
||
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
||
payload_kw["data"] = form
|
||
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
||
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||
payload_kw["data"] = cfg.data or {}
|
||
elif method != "GET":
|
||
payload_headers["Content-Type"] = "application/json"
|
||
payload_kw["json"] = cfg.data or {}
|
||
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
request_headers=dict(payload_headers) if payload_headers else None,
|
||
request_params=dict(params) if params else None,
|
||
request_data=request_body_log,
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||
|
||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||
req_task = asyncio.create_task(req_coro)
|
||
|
||
# Race: request vs. monitor (interruption)
|
||
tasks = {req_task}
|
||
if monitor_task:
|
||
tasks.add(monitor_task)
|
||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||
|
||
if monitor_task and monitor_task in done:
|
||
# Interrupted – cancel the request and abort
|
||
if req_task in pending:
|
||
req_task.cancel()
|
||
raise ProcessingInterrupted("Task cancelled")
|
||
|
||
# Otherwise, request finished
|
||
resp = await req_task
|
||
async with resp:
|
||
if resp.status >= 400:
|
||
try:
|
||
body = await resp.json()
|
||
except (ContentTypeError, json.JSONDecodeError):
|
||
body = await resp.text()
|
||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||
logging.warning(
|
||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||
method,
|
||
url,
|
||
resp.status,
|
||
delay,
|
||
attempt,
|
||
cfg.max_retries,
|
||
)
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
response_status_code=resp.status,
|
||
response_headers=dict(resp.headers),
|
||
response_content=body,
|
||
error_message=_friendly_http_message(resp.status, body),
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||
|
||
await sleep_with_interrupt(
|
||
delay,
|
||
cfg.node_cls,
|
||
cfg.wait_label if cfg.monitor_progress else None,
|
||
start_time if cfg.monitor_progress else None,
|
||
cfg.estimated_total,
|
||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||
)
|
||
delay *= cfg.retry_backoff
|
||
continue
|
||
msg = _friendly_http_message(resp.status, body)
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
response_status_code=resp.status,
|
||
response_headers=dict(resp.headers),
|
||
response_content=body,
|
||
error_message=msg,
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||
raise Exception(msg)
|
||
|
||
if expect_binary:
|
||
buff = bytearray()
|
||
last_tick = time.monotonic()
|
||
async for chunk in resp.content.iter_chunked(64 * 1024):
|
||
buff.extend(chunk)
|
||
now = time.monotonic()
|
||
if now - last_tick >= 1.0:
|
||
last_tick = now
|
||
if is_processing_interrupted():
|
||
raise ProcessingInterrupted("Task cancelled")
|
||
if cfg.monitor_progress:
|
||
_display_time_progress(
|
||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||
)
|
||
bytes_payload = bytes(buff)
|
||
operation_succeeded = True
|
||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
response_status_code=resp.status,
|
||
response_headers=dict(resp.headers),
|
||
response_content=bytes_payload,
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||
return bytes_payload
|
||
else:
|
||
try:
|
||
payload = await resp.json()
|
||
response_content_to_log: Any = payload
|
||
except (ContentTypeError, json.JSONDecodeError):
|
||
text = await resp.text()
|
||
try:
|
||
payload = json.loads(text) if text else {}
|
||
except json.JSONDecodeError:
|
||
payload = {"_raw": text}
|
||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||
operation_succeeded = True
|
||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
response_status_code=resp.status,
|
||
response_headers=dict(resp.headers),
|
||
response_content=response_content_to_log,
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||
return payload
|
||
|
||
except ProcessingInterrupted:
|
||
logging.debug("Polling was interrupted by user")
|
||
raise
|
||
except (ClientError, OSError) as e:
|
||
if attempt <= cfg.max_retries:
|
||
logging.warning(
|
||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||
method,
|
||
url,
|
||
delay,
|
||
attempt,
|
||
cfg.max_retries,
|
||
str(e),
|
||
)
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
request_headers=dict(payload_headers) if payload_headers else None,
|
||
request_params=dict(params) if params else None,
|
||
request_data=request_body_log,
|
||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||
await sleep_with_interrupt(
|
||
delay,
|
||
cfg.node_cls,
|
||
cfg.wait_label if cfg.monitor_progress else None,
|
||
start_time if cfg.monitor_progress else None,
|
||
cfg.estimated_total,
|
||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||
)
|
||
delay *= cfg.retry_backoff
|
||
continue
|
||
diag = await _diagnose_connectivity()
|
||
if not diag["internet_accessible"]:
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
request_headers=dict(payload_headers) if payload_headers else None,
|
||
request_params=dict(params) if params else None,
|
||
request_data=request_body_log,
|
||
error_message=f"LocalNetworkError: {str(e)}",
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||
raise LocalNetworkError(
|
||
"Unable to connect to the API server due to local network issues. "
|
||
"Please check your internet connection and try again."
|
||
) from e
|
||
try:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
request_headers=dict(payload_headers) if payload_headers else None,
|
||
request_params=dict(params) if params else None,
|
||
request_data=request_body_log,
|
||
error_message=f"ApiServerError: {str(e)}",
|
||
)
|
||
except Exception as _log_e:
|
||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||
raise ApiServerError(
|
||
f"The API server at {default_base_url()} is currently unreachable. "
|
||
f"The service may be experiencing issues."
|
||
) from e
|
||
finally:
|
||
stop_event.set()
|
||
if monitor_task:
|
||
monitor_task.cancel()
|
||
with contextlib.suppress(Exception):
|
||
await monitor_task
|
||
if sess:
|
||
with contextlib.suppress(Exception):
|
||
await sess.close()
|
||
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
||
_display_time_progress(
|
||
cfg.node_cls,
|
||
status=cfg.final_label_on_success,
|
||
elapsed_seconds=(
|
||
final_elapsed_seconds
|
||
if final_elapsed_seconds is not None
|
||
else int(time.monotonic() - start_time)
|
||
),
|
||
estimated_total=cfg.estimated_total,
|
||
price=None,
|
||
is_queued=False,
|
||
processing_elapsed_seconds=final_elapsed_seconds,
|
||
)
|
||
|
||
|
||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||
try:
|
||
return response_model.model_validate(payload)
|
||
except Exception as e:
|
||
logging.error(
|
||
"Response validation failed for %s: %s",
|
||
getattr(response_model, "__name__", response_model),
|
||
e,
|
||
)
|
||
raise Exception(
|
||
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
|
||
) from e
|
||
|
||
|
||
def _wrap_model_extractor(
|
||
response_model: Type[M],
|
||
extractor: Optional[Callable[[M], Any]],
|
||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||
Validates the dict into `response_model` before invoking `extractor`.
|
||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||
the same response for multiple extractors in a single poll attempt.
|
||
"""
|
||
if extractor is None:
|
||
return None
|
||
_cache: dict[int, M] = {}
|
||
|
||
def _wrapped(d: dict[str, Any]) -> Any:
|
||
try:
|
||
key = id(d)
|
||
model = _cache.get(key)
|
||
if model is None:
|
||
model = response_model.model_validate(d)
|
||
_cache[key] = model
|
||
return extractor(model)
|
||
except Exception as e:
|
||
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
|
||
raise
|
||
|
||
return _wrapped
|
||
|
||
|
||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
    """Return the set of normalized status values from *values*.

    Each item is passed through ``_normalize_status_value``; results that are
    ``None`` are dropped. A falsy *values* (``None`` or an empty container)
    yields an empty set.
    """
    if values:
        return {
            normalized
            for normalized in map(_normalize_status_value, values)
            if normalized is not None
        }
    return set()
|
||
|
||
|
||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||
if isinstance(val, str):
|
||
return val.strip().lower()
|
||
return val
|