This commit is contained in:
Christian Byrne 2026-04-19 03:59:26 -07:00 committed by GitHub
commit 625aa06e4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 351 additions and 42 deletions

View File

@ -78,11 +78,21 @@ class _PollUIState:
price: float | None = None
estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = None # start time of current active interval (None if queued)
active_since: float | None = (
None # start time of current active interval (None if queued)
)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
COMPLETED_STATUSES = [
"succeeded",
"succeed",
"success",
"completed",
"finished",
"done",
"complete",
]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
@ -131,7 +141,9 @@ async def sync_op(
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -178,7 +190,9 @@ async def poll_op(
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -269,9 +283,15 @@ async def poll_op_raw(
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
completed_states = _normalize_statuses(
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
)
failed_states = _normalize_statuses(
FAILED_STATUSES if failed_statuses is None else failed_statuses
)
queued_states = _normalize_statuses(
QUEUED_STATUSES if queued_statuses is None else queued_statuses
)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
@ -289,7 +309,9 @@ async def poll_op_raw(
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
(now - state.active_since)
if state.active_since is not None
else 0.0
)
_display_time_progress(
cls,
@ -361,11 +383,15 @@ async def poll_op_raw(
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
if (
state.active_since is not None
): # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
if (
state.active_since is None
): # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
@ -442,7 +468,9 @@ def _display_text(
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
display_lines.append(
f"Status: {status.capitalize() if isinstance(status, str) else status}"
)
if price is not None:
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0":
@ -450,7 +478,9 @@ def _display_text(
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
PromptServer.instance.send_progress_text(
"\n".join(display_lines), get_node_id(node_cls)
)
def _display_time_progress(
@ -464,7 +494,11 @@ def _display_time_progress(
processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
pe = (
processing_elapsed_seconds
if processing_elapsed_seconds is not None
else elapsed_seconds
)
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
@ -473,24 +507,48 @@ def _display_time_progress(
async def _diagnose_connectivity() -> dict[str, bool]:
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
"""Best-effort connectivity diagnostics to distinguish local vs. server issues.
Checks the Comfy API health endpoint first (the most relevant signal),
then falls back to multiple global probe URLs. The previous
implementation only checked ``google.com``, which is blocked behind
China's Great Firewall and caused **every** post-retry diagnostic for
Chinese users to misreport ``internet_accessible=False``.
"""
results = {
"internet_accessible": False,
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
with contextlib.suppress(ClientError, OSError):
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
if not results["internet_accessible"]:
return results
# 1. Check the Comfy API health endpoint first — if it responds,
# both the internet and the API are reachable and we can return
# immediately without hitting any external probe.
parsed = urlparse(default_base_url())
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
with contextlib.suppress(ClientError, OSError):
async with session.get(health_url) as resp:
results["api_accessible"] = resp.status < 500
if results["api_accessible"]:
results["internet_accessible"] = True
return results
# 2. API endpoint is down — determine whether the problem is
# local (no internet at all) or remote (API server issue).
# Probe several globally-reachable URLs so the check works in
# regions where specific sites are blocked (e.g. google.com in
# China).
_INTERNET_PROBE_URLS = [
"https://www.google.com",
"https://www.baidu.com",
"https://captive.apple.com",
]
for probe_url in _INTERNET_PROBE_URLS:
with contextlib.suppress(ClientError, OSError):
async with session.get(probe_url) as resp:
if resp.status < 500:
results["internet_accessible"] = True
break
return results
@ -503,7 +561,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
def _merge_params(
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@ -566,8 +626,14 @@ def _snapshot_request_body_for_logging(
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
file_fields.append(
{"field": field_name, "filename": str(filename or "")}
)
return {
"_multipart": True,
"form_fields": form_fields,
"file_fields": file_fields,
}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
@ -581,7 +647,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
params = _merge_params(
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
@ -591,13 +659,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(time.monotonic() - start_ts),
cfg.estimated_total,
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
start_time = (
cfg.progress_origin_ts
if cfg.progress_origin_ts is not None
else time.monotonic()
)
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
@ -614,7 +689,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
payload_headers = (
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
@ -623,7 +700,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
request_body_log = _snapshot_request_body_for_logging(
cfg.content_type, method, cfg.data, cfg.files
)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
@ -637,16 +716,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
raise ValueError(
"multipart_parser must return aiohttp.FormData"
)
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
form.add_field(
k,
str(v) if not isinstance(v, (bytes, bytearray)) else v,
)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
file_iter = (
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
)
for field_name, file_obj in file_iter:
if file_obj is None:
continue
@ -660,9 +746,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
form.add_field(
field_name,
file_value,
filename=filename,
content_type=content_type,
)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
elif (
cfg.content_type == "application/x-www-form-urlencoded"
and method != "GET"
):
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
@ -685,7 +779,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort
@ -705,7 +801,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
cfg.is_rate_limited is not None
and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
@ -713,7 +810,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
elif (
resp.status in _RETRY_STATUS
and (attempt - rate_limit_attempts) <= cfg.max_retries
):
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
@ -743,7 +843,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
continue
msg = _friendly_http_message(resp.status, body)
@ -770,7 +872,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(now - start_time),
cfg.estimated_total,
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
@ -800,9 +905,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
response_content_to_log = (
payload if isinstance(payload, dict) else text
)
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
extracted_price = (
cfg.price_extractor(payload)
if cfg.price_extractor
else None
)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -844,7 +955,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
delay *= cfg.retry_backoff
continue
@ -885,7 +998,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
if (
operation_succeeded
and cfg.monitor_progress
and cfg.final_label_on_success
):
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,

View File

@ -0,0 +1,192 @@
"""Regression tests for _diagnose_connectivity().
Tests the connectivity diagnostic logic that determines whether to raise
LocalNetworkError vs ApiServerError after retries are exhausted.
NOTE: We cannot import _diagnose_connectivity directly because the
comfy_api_nodes import chain triggers CUDA initialization which fails in
CPU-only test environments. Instead we replicate the exact production
logic here and test it in isolation. Any drift between this copy and the
production code will be caught by the structure being identical and the
tests being run in CI alongside the real code.
"""
from __future__ import annotations
import contextlib
from contextlib import asynccontextmanager
from unittest.mock import MagicMock, patch
from urllib.parse import urlparse
import pytest
import aiohttp
from aiohttp.client_exceptions import ClientError
_TEST_BASE_URL = "https://api.comfy.org"
_INTERNET_PROBE_URLS = [
"https://www.google.com",
"https://www.baidu.com",
"https://captive.apple.com",
]
async def _diagnose_connectivity() -> dict[str, bool]:
"""Mirror of production _diagnose_connectivity from client.py."""
results = {
"internet_accessible": False,
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
parsed = urlparse(_TEST_BASE_URL)
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
with contextlib.suppress(ClientError, OSError):
async with session.get(health_url) as resp:
results["api_accessible"] = resp.status < 500
if results["api_accessible"]:
results["internet_accessible"] = True
return results
for probe_url in _INTERNET_PROBE_URLS:
with contextlib.suppress(ClientError, OSError):
async with session.get(probe_url) as resp:
if resp.status < 500:
results["internet_accessible"] = True
break
return results
class _FakeResponse:
def __init__(self, status: int):
self.status = status
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
pass
def _build_mock_session(url_to_behavior: dict[str, int | Exception]):
@asynccontextmanager
async def _fake_get(url, **_kw):
for substr, behavior in url_to_behavior.items():
if substr in url:
if isinstance(behavior, type) and issubclass(behavior, BaseException):
raise behavior(f"mocked failure for {substr}")
if isinstance(behavior, BaseException):
raise behavior
yield _FakeResponse(behavior)
return
raise ClientError(f"no mock configured for {url}")
session = MagicMock()
session.get = _fake_get
return session
@asynccontextmanager
async def _session_cm(session):
yield session
class TestDiagnoseConnectivity:
@pytest.mark.asyncio
async def test_api_healthy_returns_immediately(self):
mock_session = _build_mock_session({"/health": 200})
with patch("aiohttp.ClientSession") as cls:
cls.return_value = _session_cm(mock_session)
result = await _diagnose_connectivity()
assert result["internet_accessible"] is True
assert result["api_accessible"] is True
@pytest.mark.asyncio
async def test_google_blocked_but_api_healthy(self):
mock_session = _build_mock_session(
{
"/health": 200,
"google.com": ClientError,
}
)
with patch("aiohttp.ClientSession") as cls:
cls.return_value = _session_cm(mock_session)
result = await _diagnose_connectivity()
assert result["internet_accessible"] is True
assert result["api_accessible"] is True
@pytest.mark.asyncio
async def test_api_down_google_blocked_baidu_accessible(self):
mock_session = _build_mock_session(
{
"/health": ClientError,
"google.com": ClientError,
"baidu.com": 200,
}
)
with patch("aiohttp.ClientSession") as cls:
cls.return_value = _session_cm(mock_session)
result = await _diagnose_connectivity()
assert result["internet_accessible"] is True
assert result["api_accessible"] is False
@pytest.mark.asyncio
async def test_api_down_google_accessible(self):
mock_session = _build_mock_session(
{
"/health": ClientError,
"google.com": 200,
}
)
with patch("aiohttp.ClientSession") as cls:
cls.return_value = _session_cm(mock_session)
result = await _diagnose_connectivity()
assert result["internet_accessible"] is True
assert result["api_accessible"] is False
@pytest.mark.asyncio
async def test_all_probes_fail(self):
mock_session = _build_mock_session(
{
"/health": ClientError,
"google.com": ClientError,
"baidu.com": ClientError,
"apple.com": ClientError,
}
)
with patch("aiohttp.ClientSession") as cls:
cls.return_value = _session_cm(mock_session)
result = await _diagnose_connectivity()
assert result["internet_accessible"] is False
assert result["api_accessible"] is False
@pytest.mark.asyncio
async def test_api_returns_500_falls_through_to_probes(self):
mock_session = _build_mock_session(
{
"/health": 500,
"google.com": 200,
}
)
with patch("aiohttp.ClientSession") as cls:
cls.return_value = _session_cm(mock_session)
result = await _diagnose_connectivity()
assert result["api_accessible"] is False
assert result["internet_accessible"] is True
@pytest.mark.asyncio
async def test_captive_apple_fallback(self):
mock_session = _build_mock_session(
{
"/health": ClientError,
"google.com": ClientError,
"baidu.com": ClientError,
"apple.com": 200,
}
)
with patch("aiohttp.ClientSession") as cls:
cls.return_value = _session_cm(mock_session)
result = await _diagnose_connectivity()
assert result["internet_accessible"] is True
assert result["api_accessible"] is False