mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-01 11:57:24 +08:00
feat(api-nodes): plumb auth_headers, base_url, error_parser, rate_limit_label through sync_op/poll_op
- Add auth_headers/base_url passthrough so RNP and other clients can override the default Comfy auth + base URL on a per-request basis without a node_cls. - Add _parse_retry_after helper that honors RFC 7231 Retry-After (seconds or HTTP-date), feeding the existing 429 / SERVER_BUSY / MAINTENANCE backoff. - Add rate_limit_label callback so callers can render a friendlier per-second status (e.g. 'Server busy, retrying in 30s...') during a rate-limit sleep; the in-flight monitor task is paused so the two writers don't race. - Add error_parser callback so structured protocol errors (e.g. RNP RnpProtocolError) bubble unchanged instead of being flattened by _friendly_http_message; typed errors are also re-raised from poll_op_raw. - Add allow_304 path returning None on conditional GETs. - Allow node_cls=None for non-workflow callers; _display_text becomes a no-op. - _diagnose_connectivity / ApiServerError respect the resolved base_url. Amp-Thread-ID: https://ampcode.com/threads/T-019e1889-d8bd-732f-8170-b85fd94da503 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
428c323780
commit
c899ea4ef7
@ -51,7 +51,7 @@ class ApiEndpoint:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _RequestConfig:
|
class _RequestConfig:
|
||||||
node_cls: type[IO.ComfyNode]
|
node_cls: type[IO.ComfyNode] | None
|
||||||
endpoint: ApiEndpoint
|
endpoint: ApiEndpoint
|
||||||
timeout: float
|
timeout: float
|
||||||
content_type: str
|
content_type: str
|
||||||
@ -70,6 +70,17 @@ class _RequestConfig:
|
|||||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||||
response_header_validator: Callable[[dict[str, str]], None] | None = None
|
response_header_validator: Callable[[dict[str, str]], None] | None = None
|
||||||
|
base_url: str | None = None
|
||||||
|
auth_headers: dict[str, str] | None = None
|
||||||
|
allow_304: bool = False
|
||||||
|
error_parser: Callable[[int, Any], Exception | None] | None = None
|
||||||
|
# Optional callback to render a per-second progress label while
|
||||||
|
# waiting out a rate-limit / SERVER_BUSY / MAINTENANCE retry. Called
|
||||||
|
# with ``(status, body, retry_after_s)`` and should return the label
|
||||||
|
# string used by ``_display_time_progress`` (which renders it as
|
||||||
|
# ``Status: <label>\nTime elapsed: Ns``). Returning ``None`` keeps
|
||||||
|
# the default ``cfg.wait_label``.
|
||||||
|
rate_limit_label: Callable[[int, Any, float], str | None] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -84,13 +95,40 @@ class _PollUIState:
|
|||||||
|
|
||||||
|
|
||||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_retry_after(raw: str | None) -> float | None:
|
||||||
|
"""RFC 7231 Retry-After: seconds-int or HTTP-date.
|
||||||
|
|
||||||
|
Returns the wait time in seconds, clamped to non-negative. Returns
|
||||||
|
``None`` for unparseable / missing values so the caller can fall
|
||||||
|
back to the local backoff schedule.
|
||||||
|
"""
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
raw = raw.strip()
|
||||||
|
try:
|
||||||
|
return max(0.0, float(raw))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
# HTTP-date form (rare in practice for our servers, but cheap to support).
|
||||||
|
try:
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
|
dt = parsedate_to_datetime(raw)
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
import datetime as _dt
|
||||||
|
now = _dt.datetime.now(tz=dt.tzinfo) if dt.tzinfo else _dt.datetime.utcnow()
|
||||||
|
return max(0.0, (dt - now).total_seconds())
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||||
|
|
||||||
|
|
||||||
async def sync_op(
|
async def sync_op(
|
||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode] | None,
|
||||||
endpoint: ApiEndpoint,
|
endpoint: ApiEndpoint,
|
||||||
*,
|
*,
|
||||||
response_model: type[M],
|
response_model: type[M],
|
||||||
@ -110,6 +148,9 @@ async def sync_op(
|
|||||||
monitor_progress: bool = True,
|
monitor_progress: bool = True,
|
||||||
max_retries_on_rate_limit: int = 16,
|
max_retries_on_rate_limit: int = 16,
|
||||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||||
|
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
auth_headers: dict[str, str] | None = None,
|
||||||
) -> M:
|
) -> M:
|
||||||
raw = await sync_op_raw(
|
raw = await sync_op_raw(
|
||||||
cls,
|
cls,
|
||||||
@ -131,6 +172,9 @@ async def sync_op(
|
|||||||
monitor_progress=monitor_progress,
|
monitor_progress=monitor_progress,
|
||||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||||
is_rate_limited=is_rate_limited,
|
is_rate_limited=is_rate_limited,
|
||||||
|
rate_limit_label=rate_limit_label,
|
||||||
|
base_url=base_url,
|
||||||
|
auth_headers=auth_headers,
|
||||||
)
|
)
|
||||||
if not isinstance(raw, dict):
|
if not isinstance(raw, dict):
|
||||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||||
@ -138,7 +182,7 @@ async def sync_op(
|
|||||||
|
|
||||||
|
|
||||||
async def poll_op(
|
async def poll_op(
|
||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode] | None,
|
||||||
poll_endpoint: ApiEndpoint,
|
poll_endpoint: ApiEndpoint,
|
||||||
*,
|
*,
|
||||||
response_model: type[M],
|
response_model: type[M],
|
||||||
@ -159,6 +203,11 @@ async def poll_op(
|
|||||||
cancel_endpoint: ApiEndpoint | None = None,
|
cancel_endpoint: ApiEndpoint | None = None,
|
||||||
cancel_timeout: float = 10.0,
|
cancel_timeout: float = 10.0,
|
||||||
extra_text: str | None = None,
|
extra_text: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
auth_headers: dict[str, str] | None = None,
|
||||||
|
error_parser: Callable[[int, Any], Exception | None] | None = None,
|
||||||
|
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||||
|
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||||
) -> M:
|
) -> M:
|
||||||
raw = await poll_op_raw(
|
raw = await poll_op_raw(
|
||||||
cls,
|
cls,
|
||||||
@ -180,6 +229,11 @@ async def poll_op(
|
|||||||
cancel_endpoint=cancel_endpoint,
|
cancel_endpoint=cancel_endpoint,
|
||||||
cancel_timeout=cancel_timeout,
|
cancel_timeout=cancel_timeout,
|
||||||
extra_text=extra_text,
|
extra_text=extra_text,
|
||||||
|
base_url=base_url,
|
||||||
|
auth_headers=auth_headers,
|
||||||
|
error_parser=error_parser,
|
||||||
|
is_rate_limited=is_rate_limited,
|
||||||
|
rate_limit_label=rate_limit_label,
|
||||||
)
|
)
|
||||||
if not isinstance(raw, dict):
|
if not isinstance(raw, dict):
|
||||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||||
@ -187,7 +241,7 @@ async def poll_op(
|
|||||||
|
|
||||||
|
|
||||||
async def sync_op_raw(
|
async def sync_op_raw(
|
||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode] | None,
|
||||||
endpoint: ApiEndpoint,
|
endpoint: ApiEndpoint,
|
||||||
*,
|
*,
|
||||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||||
@ -207,13 +261,26 @@ async def sync_op_raw(
|
|||||||
monitor_progress: bool = True,
|
monitor_progress: bool = True,
|
||||||
max_retries_on_rate_limit: int = 16,
|
max_retries_on_rate_limit: int = 16,
|
||||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||||
|
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||||
response_header_validator: Callable[[dict[str, str]], None] | None = None,
|
response_header_validator: Callable[[dict[str, str]], None] | None = None,
|
||||||
) -> dict[str, Any] | bytes:
|
base_url: str | None = None,
|
||||||
|
auth_headers: dict[str, str] | None = None,
|
||||||
|
allow_304: bool = False,
|
||||||
|
error_parser: Callable[[int, Any], Exception | None] | None = None,
|
||||||
|
) -> dict[str, Any] | bytes | None:
|
||||||
"""
|
"""
|
||||||
Make a single network request.
|
Make a single network request.
|
||||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||||
- If as_binary=True: returns bytes.
|
- If as_binary=True: returns bytes.
|
||||||
- response_header_validator: optional callback receiving response headers dict
|
- response_header_validator: optional callback receiving response headers dict
|
||||||
|
- base_url: override the default api.comfy.org base for this request.
|
||||||
|
- auth_headers: pre-built Authorization/X-API-KEY dict; bypasses get_auth_header.
|
||||||
|
- allow_304: when True, an HTTP 304 response returns ``None`` instead of raising.
|
||||||
|
- error_parser: when set, called on every >=400 response with
|
||||||
|
``(status, body)``; if it returns an Exception, that exception
|
||||||
|
is raised immediately and the retry/friendly-message path is
|
||||||
|
skipped. Used by RNP to surface structured ``RnpProtocolError``
|
||||||
|
envelopes that would otherwise be flattened to "API Error: ...".
|
||||||
"""
|
"""
|
||||||
if isinstance(data, BaseModel):
|
if isinstance(data, BaseModel):
|
||||||
data = data.model_dump(exclude_none=True)
|
data = data.model_dump(exclude_none=True)
|
||||||
@ -239,13 +306,18 @@ async def sync_op_raw(
|
|||||||
price_extractor=price_extractor,
|
price_extractor=price_extractor,
|
||||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||||
is_rate_limited=is_rate_limited,
|
is_rate_limited=is_rate_limited,
|
||||||
|
rate_limit_label=rate_limit_label,
|
||||||
response_header_validator=response_header_validator,
|
response_header_validator=response_header_validator,
|
||||||
|
base_url=base_url,
|
||||||
|
auth_headers=auth_headers,
|
||||||
|
allow_304=allow_304,
|
||||||
|
error_parser=error_parser,
|
||||||
)
|
)
|
||||||
return await _request_base(cfg, expect_binary=as_binary)
|
return await _request_base(cfg, expect_binary=as_binary)
|
||||||
|
|
||||||
|
|
||||||
async def poll_op_raw(
|
async def poll_op_raw(
|
||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode] | None,
|
||||||
poll_endpoint: ApiEndpoint,
|
poll_endpoint: ApiEndpoint,
|
||||||
*,
|
*,
|
||||||
status_extractor: Callable[[dict[str, Any]], str | int | None],
|
status_extractor: Callable[[dict[str, Any]], str | int | None],
|
||||||
@ -265,6 +337,11 @@ async def poll_op_raw(
|
|||||||
cancel_endpoint: ApiEndpoint | None = None,
|
cancel_endpoint: ApiEndpoint | None = None,
|
||||||
cancel_timeout: float = 10.0,
|
cancel_timeout: float = 10.0,
|
||||||
extra_text: str | None = None,
|
extra_text: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
auth_headers: dict[str, str] | None = None,
|
||||||
|
error_parser: Callable[[int, Any], Exception | None] | None = None,
|
||||||
|
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||||
|
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||||
@ -272,6 +349,14 @@ async def poll_op_raw(
|
|||||||
|
|
||||||
Uses default complete, failed and queued states assumption.
|
Uses default complete, failed and queued states assumption.
|
||||||
|
|
||||||
|
``error_parser`` and ``is_rate_limited`` are forwarded to each
|
||||||
|
per-poll ``sync_op_raw`` call so callers can surface a typed
|
||||||
|
exception for >=400 responses (e.g. an RNP structured-error
|
||||||
|
envelope) and treat protocol-specific 5xx codes (e.g. RNP
|
||||||
|
``SERVER_BUSY`` / ``MAINTENANCE``) like a 429 — both honour
|
||||||
|
``Retry-After`` and consume the rate-limit retry counter instead
|
||||||
|
of falling through the generic 5xx exponential-backoff path.
|
||||||
|
|
||||||
Returns the final JSON response from the poll endpoint.
|
Returns the final JSON response from the poll endpoint.
|
||||||
"""
|
"""
|
||||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||||
@ -286,6 +371,22 @@ async def poll_op_raw(
|
|||||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||||
stop_ticker = asyncio.Event()
|
stop_ticker = asyncio.Event()
|
||||||
|
|
||||||
|
# Wrap the user's rate_limit_label so a SERVER_BUSY/MAINTENANCE/429
|
||||||
|
# wait inside the per-poll sync_op_raw also flips the outer ticker's
|
||||||
|
# status_label — otherwise _ticker keeps writing "Status: Queued"
|
||||||
|
# over our message every second. The next successful poll resets
|
||||||
|
# status_label from the response, so no manual restore is needed.
|
||||||
|
user_rate_limit_label = rate_limit_label
|
||||||
|
|
||||||
|
def _wrapped_rate_limit_label(status: int, body: Any, retry_after_s: float) -> str | None:
|
||||||
|
label: str | None = None
|
||||||
|
if user_rate_limit_label is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
label = user_rate_limit_label(status, body, retry_after_s)
|
||||||
|
if label:
|
||||||
|
state.status_label = label
|
||||||
|
return label
|
||||||
|
|
||||||
async def _ticker():
|
async def _ticker():
|
||||||
"""Emit a UI update every second while polling is in progress."""
|
"""Emit a UI update every second while polling is in progress."""
|
||||||
try:
|
try:
|
||||||
@ -327,6 +428,11 @@ async def poll_op_raw(
|
|||||||
as_binary=False,
|
as_binary=False,
|
||||||
final_label_on_success=None,
|
final_label_on_success=None,
|
||||||
monitor_progress=False,
|
monitor_progress=False,
|
||||||
|
base_url=base_url,
|
||||||
|
auth_headers=auth_headers,
|
||||||
|
error_parser=error_parser,
|
||||||
|
is_rate_limited=is_rate_limited,
|
||||||
|
rate_limit_label=_wrapped_rate_limit_label,
|
||||||
)
|
)
|
||||||
if not isinstance(resp_json, dict):
|
if not isinstance(resp_json, dict):
|
||||||
raise Exception("Polling endpoint returned non-JSON response.")
|
raise Exception("Polling endpoint returned non-JSON response.")
|
||||||
@ -343,6 +449,8 @@ async def poll_op_raw(
|
|||||||
as_binary=False,
|
as_binary=False,
|
||||||
final_label_on_success=None,
|
final_label_on_success=None,
|
||||||
monitor_progress=False,
|
monitor_progress=False,
|
||||||
|
base_url=base_url,
|
||||||
|
auth_headers=auth_headers,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -419,6 +527,8 @@ async def poll_op_raw(
|
|||||||
as_binary=False,
|
as_binary=False,
|
||||||
final_label_on_success=None,
|
final_label_on_success=None,
|
||||||
monitor_progress=False,
|
monitor_progress=False,
|
||||||
|
base_url=base_url,
|
||||||
|
auth_headers=auth_headers,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
if not is_queued:
|
if not is_queued:
|
||||||
@ -433,6 +543,16 @@ async def poll_op_raw(
|
|||||||
except (LocalNetworkError, ApiServerError):
|
except (LocalNetworkError, ApiServerError):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Let typed protocol errors raised by ``error_parser`` (e.g.
|
||||||
|
# RnpProtocolError) bubble unchanged so callers can pattern-
|
||||||
|
# match on ``.code`` to drive resume / fallback logic. Any
|
||||||
|
# exception that exposes a string ``code`` attribute counts as
|
||||||
|
# "typed" — duck-typing avoids importing the typed-error class
|
||||||
|
# into this generic util layer. Everything else gets the
|
||||||
|
# friendlier wrapper for back-compat with existing api-node
|
||||||
|
# callers that surface the wrapped message directly to users.
|
||||||
|
if isinstance(getattr(e, "code", None), str):
|
||||||
|
raise
|
||||||
raise Exception(f"Polling aborted due to error: {e}") from e
|
raise Exception(f"Polling aborted due to error: {e}") from e
|
||||||
finally:
|
finally:
|
||||||
stop_ticker.set()
|
stop_ticker.set()
|
||||||
@ -441,12 +561,16 @@ async def poll_op_raw(
|
|||||||
|
|
||||||
|
|
||||||
def _display_text(
|
def _display_text(
|
||||||
node_cls: type[IO.ComfyNode],
|
node_cls: type[IO.ComfyNode] | None,
|
||||||
text: str | None,
|
text: str | None,
|
||||||
*,
|
*,
|
||||||
status: str | int | None = None,
|
status: str | int | None = None,
|
||||||
price: float | None = None,
|
price: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# Skip when there's no node to address — RNP / bootstrap callers
|
||||||
|
# pass cls=None on requests that aren't tied to a workflow node.
|
||||||
|
if node_cls is None:
|
||||||
|
return
|
||||||
display_lines: list[str] = []
|
display_lines: list[str] = []
|
||||||
if status:
|
if status:
|
||||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||||
@ -461,7 +585,7 @@ def _display_text(
|
|||||||
|
|
||||||
|
|
||||||
def _display_time_progress(
|
def _display_time_progress(
|
||||||
node_cls: type[IO.ComfyNode],
|
node_cls: type[IO.ComfyNode] | None,
|
||||||
status: str | int | None,
|
status: str | int | None,
|
||||||
elapsed_seconds: int,
|
elapsed_seconds: int,
|
||||||
estimated_total: int | None = None,
|
estimated_total: int | None = None,
|
||||||
@ -481,7 +605,7 @@ def _display_time_progress(
|
|||||||
_display_text(node_cls, text, status=status, price=price)
|
_display_text(node_cls, text, status=status, price=price)
|
||||||
|
|
||||||
|
|
||||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
async def _diagnose_connectivity(base_url: str | None = None) -> dict[str, bool]:
|
||||||
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
||||||
results = {
|
results = {
|
||||||
"internet_accessible": False,
|
"internet_accessible": False,
|
||||||
@ -515,7 +639,7 @@ async def _diagnose_connectivity() -> dict[str, bool]:
|
|||||||
if not results["internet_accessible"]:
|
if not results["internet_accessible"]:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
parsed = urlparse(default_base_url())
|
parsed = urlparse(base_url or default_base_url())
|
||||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||||
with contextlib.suppress(ClientError, OSError):
|
with contextlib.suppress(ClientError, OSError):
|
||||||
async with session.get(health_url) as resp:
|
async with session.get(health_url) as resp:
|
||||||
@ -604,10 +728,11 @@ def _snapshot_request_body_for_logging(
|
|||||||
|
|
||||||
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||||
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
||||||
|
resolved_base_url = cfg.base_url or default_base_url()
|
||||||
url = cfg.endpoint.path
|
url = cfg.endpoint.path
|
||||||
parsed_url = urlparse(url)
|
parsed_url = urlparse(url)
|
||||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
url = urljoin(resolved_base_url.rstrip("/") + "/", url.lstrip("/"))
|
||||||
|
|
||||||
method = cfg.endpoint.method
|
method = cfg.endpoint.method
|
||||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||||
@ -645,7 +770,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
|
|
||||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
if cfg.auth_headers is not None:
|
||||||
|
payload_headers.update(cfg.auth_headers)
|
||||||
|
elif cfg.node_cls is not None:
|
||||||
|
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||||
if cfg.endpoint.headers:
|
if cfg.endpoint.headers:
|
||||||
payload_headers.update(cfg.endpoint.headers)
|
payload_headers.update(cfg.endpoint.headers)
|
||||||
@ -726,6 +854,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
# Otherwise, request finished
|
# Otherwise, request finished
|
||||||
resp = await req_task
|
resp = await req_task
|
||||||
async with resp:
|
async with resp:
|
||||||
|
if cfg.allow_304 and resp.status == 304:
|
||||||
|
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
||||||
|
if cfg.response_header_validator:
|
||||||
|
cfg.response_header_validator(resp_headers)
|
||||||
|
operation_succeeded = True
|
||||||
|
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=resp_headers,
|
||||||
|
response_content=None,
|
||||||
|
)
|
||||||
|
return None
|
||||||
if resp.status >= 400:
|
if resp.status >= 400:
|
||||||
try:
|
try:
|
||||||
body = await resp.json()
|
body = await resp.json()
|
||||||
@ -737,12 +880,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
is_rl = resp.status == 429 or (
|
is_rl = resp.status == 429 or (
|
||||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||||
)
|
)
|
||||||
|
sleep_label = cfg.wait_label
|
||||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||||
rate_limit_attempts += 1
|
rate_limit_attempts += 1
|
||||||
wait_time = min(rate_limit_delay, 30.0)
|
# Honor server-provided Retry-After when present
|
||||||
|
# (clamped to keep a runaway header from blocking
|
||||||
|
# the executor for hours), otherwise fall back to
|
||||||
|
# the local exponential backoff.
|
||||||
|
retry_after_s = _parse_retry_after(resp.headers.get("Retry-After"))
|
||||||
|
if retry_after_s is not None:
|
||||||
|
wait_time = min(retry_after_s, 300.0)
|
||||||
|
else:
|
||||||
|
wait_time = min(rate_limit_delay, 30.0)
|
||||||
rate_limit_delay *= cfg.retry_backoff
|
rate_limit_delay *= cfg.retry_backoff
|
||||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||||
should_retry = True
|
should_retry = True
|
||||||
|
# Let callers (e.g. RNP) render a friendlier
|
||||||
|
# per-second label like
|
||||||
|
# "Server busy, retrying in 30s..." while we
|
||||||
|
# sleep — surfaced via send_progress_text by
|
||||||
|
# _display_time_progress every second.
|
||||||
|
if cfg.rate_limit_label is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
custom = cfg.rate_limit_label(
|
||||||
|
resp.status, body, wait_time
|
||||||
|
)
|
||||||
|
if custom:
|
||||||
|
sleep_label = custom
|
||||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||||
wait_time = delay
|
wait_time = delay
|
||||||
delay *= cfg.retry_backoff
|
delay *= cfg.retry_backoff
|
||||||
@ -767,15 +931,45 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
response_content=body,
|
response_content=body,
|
||||||
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
||||||
)
|
)
|
||||||
|
# Stop the in-flight monitor so the per-second
|
||||||
|
# progress label flips from cfg.wait_label
|
||||||
|
# ("Waiting for server") to the rate-limit copy
|
||||||
|
# ("Server busy, retrying in Ns...") for the
|
||||||
|
# duration of this sleep — otherwise the two
|
||||||
|
# writers race and the user sees alternating
|
||||||
|
# text every tick.
|
||||||
|
stop_event.set()
|
||||||
|
if monitor_task:
|
||||||
|
monitor_task.cancel()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await monitor_task
|
||||||
|
monitor_task = None
|
||||||
await sleep_with_interrupt(
|
await sleep_with_interrupt(
|
||||||
wait_time,
|
wait_time,
|
||||||
cfg.node_cls,
|
cfg.node_cls,
|
||||||
cfg.wait_label if cfg.monitor_progress else None,
|
sleep_label if cfg.monitor_progress else None,
|
||||||
start_time if cfg.monitor_progress else None,
|
start_time if cfg.monitor_progress else None,
|
||||||
cfg.estimated_total,
|
cfg.estimated_total,
|
||||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
# Retries either weren't applicable or have been exhausted —
|
||||||
|
# give the caller's error_parser a chance to surface a
|
||||||
|
# structured exception (e.g. RNP RnpProtocolError) before
|
||||||
|
# we flatten the response with _friendly_http_message.
|
||||||
|
if cfg.error_parser is not None:
|
||||||
|
custom_exc = cfg.error_parser(resp.status, body)
|
||||||
|
if custom_exc is not None:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=body,
|
||||||
|
error_message=f"{type(custom_exc).__name__}: {custom_exc}",
|
||||||
|
)
|
||||||
|
raise custom_exc
|
||||||
msg = _friendly_http_message(resp.status, body)
|
msg = _friendly_http_message(resp.status, body)
|
||||||
request_logger.log_request_response(
|
request_logger.log_request_response(
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
@ -878,7 +1072,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
)
|
)
|
||||||
delay *= cfg.retry_backoff
|
delay *= cfg.retry_backoff
|
||||||
continue
|
continue
|
||||||
diag = await _diagnose_connectivity()
|
diag = await _diagnose_connectivity(resolved_base_url)
|
||||||
if not diag["internet_accessible"]:
|
if not diag["internet_accessible"]:
|
||||||
request_logger.log_request_response(
|
request_logger.log_request_response(
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
@ -903,7 +1097,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
error_message=f"ApiServerError: {str(e)}",
|
error_message=f"ApiServerError: {str(e)}",
|
||||||
)
|
)
|
||||||
raise ApiServerError(
|
raise ApiServerError(
|
||||||
f"The API server at {default_base_url()} is currently unreachable. "
|
f"The API server at {resolved_base_url} is currently unreachable. "
|
||||||
f"The service may be experiencing issues."
|
f"The service may be experiencing issues."
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user