mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
Merge 8587446e4a into 138571da95
This commit is contained in:
commit
625aa06e4e
@ -78,11 +78,21 @@ class _PollUIState:
|
|||||||
price: float | None = None
|
price: float | None = None
|
||||||
estimated_duration: int | None = None
|
estimated_duration: int | None = None
|
||||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||||
active_since: float | None = None # start time of current active interval (None if queued)
|
active_since: float | None = (
|
||||||
|
None # start time of current active interval (None if queued)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
COMPLETED_STATUSES = [
|
||||||
|
"succeeded",
|
||||||
|
"succeed",
|
||||||
|
"success",
|
||||||
|
"completed",
|
||||||
|
"finished",
|
||||||
|
"done",
|
||||||
|
"complete",
|
||||||
|
]
|
||||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||||
|
|
||||||
@ -131,7 +141,9 @@ async def sync_op(
|
|||||||
is_rate_limited=is_rate_limited,
|
is_rate_limited=is_rate_limited,
|
||||||
)
|
)
|
||||||
if not isinstance(raw, dict):
|
if not isinstance(raw, dict):
|
||||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
raise Exception(
|
||||||
|
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
|
||||||
|
)
|
||||||
return _validate_or_raise(response_model, raw)
|
return _validate_or_raise(response_model, raw)
|
||||||
|
|
||||||
|
|
||||||
@ -178,7 +190,9 @@ async def poll_op(
|
|||||||
cancel_timeout=cancel_timeout,
|
cancel_timeout=cancel_timeout,
|
||||||
)
|
)
|
||||||
if not isinstance(raw, dict):
|
if not isinstance(raw, dict):
|
||||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
raise Exception(
|
||||||
|
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
|
||||||
|
)
|
||||||
return _validate_or_raise(response_model, raw)
|
return _validate_or_raise(response_model, raw)
|
||||||
|
|
||||||
|
|
||||||
@ -269,9 +283,15 @@ async def poll_op_raw(
|
|||||||
|
|
||||||
Returns the final JSON response from the poll endpoint.
|
Returns the final JSON response from the poll endpoint.
|
||||||
"""
|
"""
|
||||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
completed_states = _normalize_statuses(
|
||||||
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
|
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
|
||||||
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
|
)
|
||||||
|
failed_states = _normalize_statuses(
|
||||||
|
FAILED_STATUSES if failed_statuses is None else failed_statuses
|
||||||
|
)
|
||||||
|
queued_states = _normalize_statuses(
|
||||||
|
QUEUED_STATUSES if queued_statuses is None else queued_statuses
|
||||||
|
)
|
||||||
started = time.monotonic()
|
started = time.monotonic()
|
||||||
consumed_attempts = 0 # counts only non-queued polls
|
consumed_attempts = 0 # counts only non-queued polls
|
||||||
|
|
||||||
@ -289,7 +309,9 @@ async def poll_op_raw(
|
|||||||
break
|
break
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
proc_elapsed = state.base_processing_elapsed + (
|
proc_elapsed = state.base_processing_elapsed + (
|
||||||
(now - state.active_since) if state.active_since is not None else 0.0
|
(now - state.active_since)
|
||||||
|
if state.active_since is not None
|
||||||
|
else 0.0
|
||||||
)
|
)
|
||||||
_display_time_progress(
|
_display_time_progress(
|
||||||
cls,
|
cls,
|
||||||
@ -361,11 +383,15 @@ async def poll_op_raw(
|
|||||||
is_queued = status in queued_states
|
is_queued = status in queued_states
|
||||||
|
|
||||||
if is_queued:
|
if is_queued:
|
||||||
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
|
if (
|
||||||
|
state.active_since is not None
|
||||||
|
): # If we just moved from active -> queued, close the active interval
|
||||||
state.base_processing_elapsed += now_ts - state.active_since
|
state.base_processing_elapsed += now_ts - state.active_since
|
||||||
state.active_since = None
|
state.active_since = None
|
||||||
else:
|
else:
|
||||||
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
|
if (
|
||||||
|
state.active_since is None
|
||||||
|
): # If we just moved from queued -> active, open a new active interval
|
||||||
state.active_since = now_ts
|
state.active_since = now_ts
|
||||||
|
|
||||||
state.is_queued = is_queued
|
state.is_queued = is_queued
|
||||||
@ -442,7 +468,9 @@ def _display_text(
|
|||||||
) -> None:
|
) -> None:
|
||||||
display_lines: list[str] = []
|
display_lines: list[str] = []
|
||||||
if status:
|
if status:
|
||||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
display_lines.append(
|
||||||
|
f"Status: {status.capitalize() if isinstance(status, str) else status}"
|
||||||
|
)
|
||||||
if price is not None:
|
if price is not None:
|
||||||
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
|
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
|
||||||
if p != "0":
|
if p != "0":
|
||||||
@ -450,7 +478,9 @@ def _display_text(
|
|||||||
if text is not None:
|
if text is not None:
|
||||||
display_lines.append(text)
|
display_lines.append(text)
|
||||||
if display_lines:
|
if display_lines:
|
||||||
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
|
PromptServer.instance.send_progress_text(
|
||||||
|
"\n".join(display_lines), get_node_id(node_cls)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _display_time_progress(
|
def _display_time_progress(
|
||||||
@ -464,7 +494,11 @@ def _display_time_progress(
|
|||||||
processing_elapsed_seconds: int | None = None,
|
processing_elapsed_seconds: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
pe = (
|
||||||
|
processing_elapsed_seconds
|
||||||
|
if processing_elapsed_seconds is not None
|
||||||
|
else elapsed_seconds
|
||||||
|
)
|
||||||
remaining = max(0, int(estimated_total) - int(pe))
|
remaining = max(0, int(estimated_total) - int(pe))
|
||||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||||
else:
|
else:
|
||||||
@ -473,24 +507,48 @@ def _display_time_progress(
|
|||||||
|
|
||||||
|
|
||||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||||
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
"""Best-effort connectivity diagnostics to distinguish local vs. server issues.
|
||||||
|
|
||||||
|
Checks the Comfy API health endpoint first (the most relevant signal),
|
||||||
|
then falls back to multiple global probe URLs. The previous
|
||||||
|
implementation only checked ``google.com``, which is blocked behind
|
||||||
|
China's Great Firewall and caused **every** post-retry diagnostic for
|
||||||
|
Chinese users to misreport ``internet_accessible=False``.
|
||||||
|
"""
|
||||||
results = {
|
results = {
|
||||||
"internet_accessible": False,
|
"internet_accessible": False,
|
||||||
"api_accessible": False,
|
"api_accessible": False,
|
||||||
}
|
}
|
||||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
with contextlib.suppress(ClientError, OSError):
|
# 1. Check the Comfy API health endpoint first — if it responds,
|
||||||
async with session.get("https://www.google.com") as resp:
|
# both the internet and the API are reachable and we can return
|
||||||
results["internet_accessible"] = resp.status < 500
|
# immediately without hitting any external probe.
|
||||||
if not results["internet_accessible"]:
|
|
||||||
return results
|
|
||||||
|
|
||||||
parsed = urlparse(default_base_url())
|
parsed = urlparse(default_base_url())
|
||||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||||
with contextlib.suppress(ClientError, OSError):
|
with contextlib.suppress(ClientError, OSError):
|
||||||
async with session.get(health_url) as resp:
|
async with session.get(health_url) as resp:
|
||||||
results["api_accessible"] = resp.status < 500
|
results["api_accessible"] = resp.status < 500
|
||||||
|
if results["api_accessible"]:
|
||||||
|
results["internet_accessible"] = True
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 2. API endpoint is down — determine whether the problem is
|
||||||
|
# local (no internet at all) or remote (API server issue).
|
||||||
|
# Probe several globally-reachable URLs so the check works in
|
||||||
|
# regions where specific sites are blocked (e.g. google.com in
|
||||||
|
# China).
|
||||||
|
_INTERNET_PROBE_URLS = [
|
||||||
|
"https://www.google.com",
|
||||||
|
"https://www.baidu.com",
|
||||||
|
"https://captive.apple.com",
|
||||||
|
]
|
||||||
|
for probe_url in _INTERNET_PROBE_URLS:
|
||||||
|
with contextlib.suppress(ClientError, OSError):
|
||||||
|
async with session.get(probe_url) as resp:
|
||||||
|
if resp.status < 500:
|
||||||
|
results["internet_accessible"] = True
|
||||||
|
break
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -503,7 +561,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
|||||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||||
|
|
||||||
|
|
||||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
|
def _merge_params(
|
||||||
|
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
|
||||||
|
) -> dict[str, Any]:
|
||||||
params = dict(endpoint_params or {})
|
params = dict(endpoint_params or {})
|
||||||
if method.upper() == "GET" and data:
|
if method.upper() == "GET" and data:
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
@ -566,8 +626,14 @@ def _snapshot_request_body_for_logging(
|
|||||||
filename = file_obj[0]
|
filename = file_obj[0]
|
||||||
else:
|
else:
|
||||||
filename = getattr(file_obj, "name", field_name)
|
filename = getattr(file_obj, "name", field_name)
|
||||||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
file_fields.append(
|
||||||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
{"field": field_name, "filename": str(filename or "")}
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"_multipart": True,
|
||||||
|
"form_fields": form_fields,
|
||||||
|
"file_fields": file_fields,
|
||||||
|
}
|
||||||
if content_type == "application/x-www-form-urlencoded":
|
if content_type == "application/x-www-form-urlencoded":
|
||||||
return data or {}
|
return data or {}
|
||||||
return data or {}
|
return data or {}
|
||||||
@ -581,7 +647,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||||
|
|
||||||
method = cfg.endpoint.method
|
method = cfg.endpoint.method
|
||||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
params = _merge_params(
|
||||||
|
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
|
||||||
|
)
|
||||||
|
|
||||||
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||||||
"""Every second: update elapsed time and signal interruption."""
|
"""Every second: update elapsed time and signal interruption."""
|
||||||
@ -591,13 +659,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
return
|
return
|
||||||
if cfg.monitor_progress:
|
if cfg.monitor_progress:
|
||||||
_display_time_progress(
|
_display_time_progress(
|
||||||
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
cfg.node_cls,
|
||||||
|
cfg.wait_label,
|
||||||
|
int(time.monotonic() - start_ts),
|
||||||
|
cfg.estimated_total,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(1.0)
|
await asyncio.sleep(1.0)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return # normal shutdown
|
return # normal shutdown
|
||||||
|
|
||||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
start_time = (
|
||||||
|
cfg.progress_origin_ts
|
||||||
|
if cfg.progress_origin_ts is not None
|
||||||
|
else time.monotonic()
|
||||||
|
)
|
||||||
attempt = 0
|
attempt = 0
|
||||||
delay = cfg.retry_delay
|
delay = cfg.retry_delay
|
||||||
rate_limit_attempts = 0
|
rate_limit_attempts = 0
|
||||||
@ -614,7 +689,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||||
|
|
||||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
payload_headers = (
|
||||||
|
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||||
|
)
|
||||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||||
if cfg.endpoint.headers:
|
if cfg.endpoint.headers:
|
||||||
@ -623,7 +700,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||||||
if method == "GET":
|
if method == "GET":
|
||||||
payload_headers.pop("Content-Type", None)
|
payload_headers.pop("Content-Type", None)
|
||||||
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
request_body_log = _snapshot_request_body_for_logging(
|
||||||
|
cfg.content_type, method, cfg.data, cfg.files
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if cfg.monitor_progress:
|
if cfg.monitor_progress:
|
||||||
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||||||
@ -637,16 +716,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
if cfg.multipart_parser and cfg.data:
|
if cfg.multipart_parser and cfg.data:
|
||||||
form = cfg.multipart_parser(cfg.data)
|
form = cfg.multipart_parser(cfg.data)
|
||||||
if not isinstance(form, aiohttp.FormData):
|
if not isinstance(form, aiohttp.FormData):
|
||||||
raise ValueError("multipart_parser must return aiohttp.FormData")
|
raise ValueError(
|
||||||
|
"multipart_parser must return aiohttp.FormData"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
form = aiohttp.FormData(default_to_multipart=True)
|
form = aiohttp.FormData(default_to_multipart=True)
|
||||||
if cfg.data:
|
if cfg.data:
|
||||||
for k, v in cfg.data.items():
|
for k, v in cfg.data.items():
|
||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
form.add_field(
|
||||||
|
k,
|
||||||
|
str(v) if not isinstance(v, (bytes, bytearray)) else v,
|
||||||
|
)
|
||||||
if cfg.files:
|
if cfg.files:
|
||||||
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
file_iter = (
|
||||||
|
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||||
|
)
|
||||||
for field_name, file_obj in file_iter:
|
for field_name, file_obj in file_iter:
|
||||||
if file_obj is None:
|
if file_obj is None:
|
||||||
continue
|
continue
|
||||||
@ -660,9 +746,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
if isinstance(file_value, BytesIO):
|
if isinstance(file_value, BytesIO):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
file_value.seek(0)
|
file_value.seek(0)
|
||||||
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
form.add_field(
|
||||||
|
field_name,
|
||||||
|
file_value,
|
||||||
|
filename=filename,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
payload_kw["data"] = form
|
payload_kw["data"] = form
|
||||||
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
elif (
|
||||||
|
cfg.content_type == "application/x-www-form-urlencoded"
|
||||||
|
and method != "GET"
|
||||||
|
):
|
||||||
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||||
payload_kw["data"] = cfg.data or {}
|
payload_kw["data"] = cfg.data or {}
|
||||||
elif method != "GET":
|
elif method != "GET":
|
||||||
@ -685,7 +779,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
tasks = {req_task}
|
tasks = {req_task}
|
||||||
if monitor_task:
|
if monitor_task:
|
||||||
tasks.add(monitor_task)
|
tasks.add(monitor_task)
|
||||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
done, pending = await asyncio.wait(
|
||||||
|
tasks, return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
|
||||||
if monitor_task and monitor_task in done:
|
if monitor_task and monitor_task in done:
|
||||||
# Interrupted – cancel the request and abort
|
# Interrupted – cancel the request and abort
|
||||||
@ -705,7 +801,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
wait_time = 0.0
|
wait_time = 0.0
|
||||||
retry_label = ""
|
retry_label = ""
|
||||||
is_rl = resp.status == 429 or (
|
is_rl = resp.status == 429 or (
|
||||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
cfg.is_rate_limited is not None
|
||||||
|
and cfg.is_rate_limited(resp.status, body)
|
||||||
)
|
)
|
||||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||||
rate_limit_attempts += 1
|
rate_limit_attempts += 1
|
||||||
@ -713,7 +810,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
rate_limit_delay *= cfg.retry_backoff
|
rate_limit_delay *= cfg.retry_backoff
|
||||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||||
should_retry = True
|
should_retry = True
|
||||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
elif (
|
||||||
|
resp.status in _RETRY_STATUS
|
||||||
|
and (attempt - rate_limit_attempts) <= cfg.max_retries
|
||||||
|
):
|
||||||
wait_time = delay
|
wait_time = delay
|
||||||
delay *= cfg.retry_backoff
|
delay *= cfg.retry_backoff
|
||||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||||
@ -743,7 +843,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
cfg.wait_label if cfg.monitor_progress else None,
|
cfg.wait_label if cfg.monitor_progress else None,
|
||||||
start_time if cfg.monitor_progress else None,
|
start_time if cfg.monitor_progress else None,
|
||||||
cfg.estimated_total,
|
cfg.estimated_total,
|
||||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
display_callback=_display_time_progress
|
||||||
|
if cfg.monitor_progress
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
msg = _friendly_http_message(resp.status, body)
|
msg = _friendly_http_message(resp.status, body)
|
||||||
@ -770,7 +872,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
raise ProcessingInterrupted("Task cancelled")
|
raise ProcessingInterrupted("Task cancelled")
|
||||||
if cfg.monitor_progress:
|
if cfg.monitor_progress:
|
||||||
_display_time_progress(
|
_display_time_progress(
|
||||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
cfg.node_cls,
|
||||||
|
cfg.wait_label,
|
||||||
|
int(now - start_time),
|
||||||
|
cfg.estimated_total,
|
||||||
)
|
)
|
||||||
bytes_payload = bytes(buff)
|
bytes_payload = bytes(buff)
|
||||||
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
||||||
@ -800,9 +905,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
payload = json.loads(text) if text else {}
|
payload = json.loads(text) if text else {}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
payload = {"_raw": text}
|
payload = {"_raw": text}
|
||||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
response_content_to_log = (
|
||||||
|
payload if isinstance(payload, dict) else text
|
||||||
|
)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
extracted_price = (
|
||||||
|
cfg.price_extractor(payload)
|
||||||
|
if cfg.price_extractor
|
||||||
|
else None
|
||||||
|
)
|
||||||
operation_succeeded = True
|
operation_succeeded = True
|
||||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||||
request_logger.log_request_response(
|
request_logger.log_request_response(
|
||||||
@ -844,7 +955,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
cfg.wait_label if cfg.monitor_progress else None,
|
cfg.wait_label if cfg.monitor_progress else None,
|
||||||
start_time if cfg.monitor_progress else None,
|
start_time if cfg.monitor_progress else None,
|
||||||
cfg.estimated_total,
|
cfg.estimated_total,
|
||||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
display_callback=_display_time_progress
|
||||||
|
if cfg.monitor_progress
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
delay *= cfg.retry_backoff
|
delay *= cfg.retry_backoff
|
||||||
continue
|
continue
|
||||||
@ -885,7 +998,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
if sess:
|
if sess:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
await sess.close()
|
await sess.close()
|
||||||
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
if (
|
||||||
|
operation_succeeded
|
||||||
|
and cfg.monitor_progress
|
||||||
|
and cfg.final_label_on_success
|
||||||
|
):
|
||||||
_display_time_progress(
|
_display_time_progress(
|
||||||
cfg.node_cls,
|
cfg.node_cls,
|
||||||
status=cfg.final_label_on_success,
|
status=cfg.final_label_on_success,
|
||||||
|
|||||||
192
tests/test_diagnose_connectivity.py
Normal file
192
tests/test_diagnose_connectivity.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""Regression tests for _diagnose_connectivity().
|
||||||
|
|
||||||
|
Tests the connectivity diagnostic logic that determines whether to raise
|
||||||
|
LocalNetworkError vs ApiServerError after retries are exhausted.
|
||||||
|
|
||||||
|
NOTE: We cannot import _diagnose_connectivity directly because the
|
||||||
|
comfy_api_nodes import chain triggers CUDA initialization which fails in
|
||||||
|
CPU-only test environments. Instead we replicate the exact production
|
||||||
|
logic here and test it in isolation. Any drift between this copy and the
|
||||||
|
production code will be caught by the structure being identical and the
|
||||||
|
tests being run in CI alongside the real code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp.client_exceptions import ClientError
|
||||||
|
|
||||||
|
|
||||||
|
_TEST_BASE_URL = "https://api.comfy.org"
|
||||||
|
|
||||||
|
_INTERNET_PROBE_URLS = [
|
||||||
|
"https://www.google.com",
|
||||||
|
"https://www.baidu.com",
|
||||||
|
"https://captive.apple.com",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||||
|
"""Mirror of production _diagnose_connectivity from client.py."""
|
||||||
|
results = {
|
||||||
|
"internet_accessible": False,
|
||||||
|
"api_accessible": False,
|
||||||
|
}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
parsed = urlparse(_TEST_BASE_URL)
|
||||||
|
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||||
|
with contextlib.suppress(ClientError, OSError):
|
||||||
|
async with session.get(health_url) as resp:
|
||||||
|
results["api_accessible"] = resp.status < 500
|
||||||
|
if results["api_accessible"]:
|
||||||
|
results["internet_accessible"] = True
|
||||||
|
return results
|
||||||
|
|
||||||
|
for probe_url in _INTERNET_PROBE_URLS:
|
||||||
|
with contextlib.suppress(ClientError, OSError):
|
||||||
|
async with session.get(probe_url) as resp:
|
||||||
|
if resp.status < 500:
|
||||||
|
results["internet_accessible"] = True
|
||||||
|
break
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
def __init__(self, status: int):
|
||||||
|
self.status = status
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mock_session(url_to_behavior: dict[str, int | Exception]):
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_get(url, **_kw):
|
||||||
|
for substr, behavior in url_to_behavior.items():
|
||||||
|
if substr in url:
|
||||||
|
if isinstance(behavior, type) and issubclass(behavior, BaseException):
|
||||||
|
raise behavior(f"mocked failure for {substr}")
|
||||||
|
if isinstance(behavior, BaseException):
|
||||||
|
raise behavior
|
||||||
|
yield _FakeResponse(behavior)
|
||||||
|
return
|
||||||
|
raise ClientError(f"no mock configured for {url}")
|
||||||
|
|
||||||
|
session = MagicMock()
|
||||||
|
session.get = _fake_get
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _session_cm(session):
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiagnoseConnectivity:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_healthy_returns_immediately(self):
|
||||||
|
mock_session = _build_mock_session({"/health": 200})
|
||||||
|
with patch("aiohttp.ClientSession") as cls:
|
||||||
|
cls.return_value = _session_cm(mock_session)
|
||||||
|
result = await _diagnose_connectivity()
|
||||||
|
assert result["internet_accessible"] is True
|
||||||
|
assert result["api_accessible"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_google_blocked_but_api_healthy(self):
|
||||||
|
mock_session = _build_mock_session(
|
||||||
|
{
|
||||||
|
"/health": 200,
|
||||||
|
"google.com": ClientError,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with patch("aiohttp.ClientSession") as cls:
|
||||||
|
cls.return_value = _session_cm(mock_session)
|
||||||
|
result = await _diagnose_connectivity()
|
||||||
|
assert result["internet_accessible"] is True
|
||||||
|
assert result["api_accessible"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_down_google_blocked_baidu_accessible(self):
|
||||||
|
mock_session = _build_mock_session(
|
||||||
|
{
|
||||||
|
"/health": ClientError,
|
||||||
|
"google.com": ClientError,
|
||||||
|
"baidu.com": 200,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with patch("aiohttp.ClientSession") as cls:
|
||||||
|
cls.return_value = _session_cm(mock_session)
|
||||||
|
result = await _diagnose_connectivity()
|
||||||
|
assert result["internet_accessible"] is True
|
||||||
|
assert result["api_accessible"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_down_google_accessible(self):
|
||||||
|
mock_session = _build_mock_session(
|
||||||
|
{
|
||||||
|
"/health": ClientError,
|
||||||
|
"google.com": 200,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with patch("aiohttp.ClientSession") as cls:
|
||||||
|
cls.return_value = _session_cm(mock_session)
|
||||||
|
result = await _diagnose_connectivity()
|
||||||
|
assert result["internet_accessible"] is True
|
||||||
|
assert result["api_accessible"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_probes_fail(self):
|
||||||
|
mock_session = _build_mock_session(
|
||||||
|
{
|
||||||
|
"/health": ClientError,
|
||||||
|
"google.com": ClientError,
|
||||||
|
"baidu.com": ClientError,
|
||||||
|
"apple.com": ClientError,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with patch("aiohttp.ClientSession") as cls:
|
||||||
|
cls.return_value = _session_cm(mock_session)
|
||||||
|
result = await _diagnose_connectivity()
|
||||||
|
assert result["internet_accessible"] is False
|
||||||
|
assert result["api_accessible"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_returns_500_falls_through_to_probes(self):
|
||||||
|
mock_session = _build_mock_session(
|
||||||
|
{
|
||||||
|
"/health": 500,
|
||||||
|
"google.com": 200,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with patch("aiohttp.ClientSession") as cls:
|
||||||
|
cls.return_value = _session_cm(mock_session)
|
||||||
|
result = await _diagnose_connectivity()
|
||||||
|
assert result["api_accessible"] is False
|
||||||
|
assert result["internet_accessible"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_captive_apple_fallback(self):
|
||||||
|
mock_session = _build_mock_session(
|
||||||
|
{
|
||||||
|
"/health": ClientError,
|
||||||
|
"google.com": ClientError,
|
||||||
|
"baidu.com": ClientError,
|
||||||
|
"apple.com": 200,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with patch("aiohttp.ClientSession") as cls:
|
||||||
|
cls.return_value = _session_cm(mock_session)
|
||||||
|
result = await _diagnose_connectivity()
|
||||||
|
assert result["internet_accessible"] is True
|
||||||
|
assert result["api_accessible"] is False
|
||||||
Loading…
Reference in New Issue
Block a user