fix: use multi-endpoint connectivity check to support China/GFW users

The _diagnose_connectivity() function previously only probed google.com
to determine whether the user has internet access. Since google.com is
blocked by China's Great Firewall, Chinese users were always misdiagnosed
as having no internet, causing misleading LocalNetworkError messages.

Now checks the Comfy API health endpoint first (the most relevant
signal), then falls back to multiple probe URLs (google.com, baidu.com,
captive.apple.com) to support users in regions where specific sites are
blocked.
This commit is contained in:
Glary-Bot 2026-04-19 10:42:56 +00:00
parent 138571da95
commit 4f101765a3

View File

@ -78,11 +78,21 @@ class _PollUIState:
price: float | None = None
estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = None # start time of current active interval (None if queued)
active_since: float | None = (
None # start time of current active interval (None if queued)
)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
COMPLETED_STATUSES = [
"succeeded",
"succeed",
"success",
"completed",
"finished",
"done",
"complete",
]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
@ -131,7 +141,9 @@ async def sync_op(
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -178,7 +190,9 @@ async def poll_op(
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -269,9 +283,15 @@ async def poll_op_raw(
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
completed_states = _normalize_statuses(
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
)
failed_states = _normalize_statuses(
FAILED_STATUSES if failed_statuses is None else failed_statuses
)
queued_states = _normalize_statuses(
QUEUED_STATUSES if queued_statuses is None else queued_statuses
)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
@ -289,7 +309,9 @@ async def poll_op_raw(
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
(now - state.active_since)
if state.active_since is not None
else 0.0
)
_display_time_progress(
cls,
@ -361,11 +383,15 @@ async def poll_op_raw(
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
if (
state.active_since is not None
): # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
if (
state.active_since is None
): # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
@ -442,7 +468,9 @@ def _display_text(
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
display_lines.append(
f"Status: {status.capitalize() if isinstance(status, str) else status}"
)
if price is not None:
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0":
@ -450,7 +478,9 @@ def _display_text(
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
PromptServer.instance.send_progress_text(
"\n".join(display_lines), get_node_id(node_cls)
)
def _display_time_progress(
@ -464,7 +494,11 @@ def _display_time_progress(
processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
pe = (
processing_elapsed_seconds
if processing_elapsed_seconds is not None
else elapsed_seconds
)
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
@ -473,24 +507,48 @@ def _display_time_progress(
async def _diagnose_connectivity() -> dict[str, bool]:
    """Best-effort connectivity diagnostics to distinguish local vs. server issues.

    Checks the Comfy API health endpoint first (the most relevant signal),
    then falls back to multiple global probe URLs. The previous
    implementation only checked ``google.com``, which is blocked behind
    China's Great Firewall and caused every post-retry diagnostic for
    Chinese users to misreport ``internet_accessible=False``.

    Returns:
        Dict with two booleans — ``internet_accessible`` (some public URL
        answered) and ``api_accessible`` (the Comfy API health endpoint
        answered). Both stay ``False`` when every probe fails; network
        errors are swallowed, never raised.
    """
    results = {
        "internet_accessible": False,
        "api_accessible": False,
    }
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # 1. Check the Comfy API health endpoint first — if it responds,
        #    both the internet and the API are reachable and we can return
        #    immediately without hitting any external probe.
        parsed = urlparse(default_base_url())
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        with contextlib.suppress(ClientError, OSError):
            async with session.get(health_url) as resp:
                # Any non-5xx answer means the server is up and routing works.
                results["api_accessible"] = resp.status < 500
                if results["api_accessible"]:
                    results["internet_accessible"] = True
                    return results

        # 2. API endpoint is down — determine whether the problem is
        #    local (no internet at all) or remote (API server issue).
        #    Probe several globally-reachable URLs so the check works in
        #    regions where specific sites are blocked (e.g. google.com in
        #    China).
        _INTERNET_PROBE_URLS = (
            "https://www.google.com",
            "https://www.baidu.com",
            "https://captive.apple.com",
        )
        for probe_url in _INTERNET_PROBE_URLS:
            with contextlib.suppress(ClientError, OSError):
                async with session.get(probe_url) as resp:
                    if resp.status < 500:
                        results["internet_accessible"] = True
                        break  # one successful probe is enough
    return results
@ -503,7 +561,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
def _merge_params(
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@ -566,8 +626,14 @@ def _snapshot_request_body_for_logging(
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
file_fields.append(
{"field": field_name, "filename": str(filename or "")}
)
return {
"_multipart": True,
"form_fields": form_fields,
"file_fields": file_fields,
}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
@ -581,7 +647,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
params = _merge_params(
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
@ -591,13 +659,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(time.monotonic() - start_ts),
cfg.estimated_total,
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
start_time = (
cfg.progress_origin_ts
if cfg.progress_origin_ts is not None
else time.monotonic()
)
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
@ -614,7 +689,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
payload_headers = (
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
@ -623,7 +700,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
request_body_log = _snapshot_request_body_for_logging(
cfg.content_type, method, cfg.data, cfg.files
)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
@ -637,16 +716,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
raise ValueError(
"multipart_parser must return aiohttp.FormData"
)
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
form.add_field(
k,
str(v) if not isinstance(v, (bytes, bytearray)) else v,
)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
file_iter = (
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
)
for field_name, file_obj in file_iter:
if file_obj is None:
continue
@ -660,9 +746,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
form.add_field(
field_name,
file_value,
filename=filename,
content_type=content_type,
)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
elif (
cfg.content_type == "application/x-www-form-urlencoded"
and method != "GET"
):
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
@ -685,7 +779,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort
@ -705,7 +801,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
cfg.is_rate_limited is not None
and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
@ -713,7 +810,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
elif (
resp.status in _RETRY_STATUS
and (attempt - rate_limit_attempts) <= cfg.max_retries
):
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
@ -743,7 +843,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
continue
msg = _friendly_http_message(resp.status, body)
@ -770,7 +872,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(now - start_time),
cfg.estimated_total,
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
@ -800,9 +905,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
response_content_to_log = (
payload if isinstance(payload, dict) else text
)
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
extracted_price = (
cfg.price_extractor(payload)
if cfg.price_extractor
else None
)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -844,7 +955,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
delay *= cfg.retry_backoff
continue
@ -885,7 +998,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
if (
operation_succeeded
and cfg.monitor_progress
and cfg.final_label_on_success
):
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,