From 4f101765a3b1c3467cc6f9e83260a10f942b25c1 Mon Sep 17 00:00:00 2001 From: Glary-Bot Date: Sun, 19 Apr 2026 10:42:56 +0000 Subject: [PATCH] fix: use multi-endpoint connectivity check to support China/GFW users The _diagnose_connectivity() function previously only probed google.com to determine whether the user has internet access. Since google.com is blocked by China's Great Firewall, Chinese users were always misdiagnosed as having no internet, causing misleading LocalNetworkError messages. Now checks the Comfy API health endpoint first (the most relevant signal), then falls back to multiple probe URLs (google.com, baidu.com, captive.apple.com) to support users in regions where specific sites are blocked. --- comfy_api_nodes/util/client.py | 201 ++++++++++++++++++++++++++------- 1 file changed, 159 insertions(+), 42 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9d730b81a..d55e63610 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -78,11 +78,21 @@ class _PollUIState: price: float | None = None estimated_duration: int | None = None base_processing_elapsed: float = 0.0 # sum of completed active intervals - active_since: float | None = None # start time of current active interval (None if queued) + active_since: float | None = ( + None # start time of current active interval (None if queued) + ) _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] +COMPLETED_STATUSES = [ + "succeeded", + "succeed", + "success", + "completed", + "finished", + "done", + "complete", +] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] @@ -131,7 +141,9 @@ async def sync_op( is_rate_limited=is_rate_limited, ) if not isinstance(raw, dict): - raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + raise Exception( + "Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)." + ) return _validate_or_raise(response_model, raw) @@ -178,7 +190,9 @@ async def poll_op( cancel_timeout=cancel_timeout, ) if not isinstance(raw, dict): - raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + raise Exception( + "Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)." + ) return _validate_or_raise(response_model, raw) @@ -269,9 +283,15 @@ async def poll_op_raw( Returns the final JSON response from the poll endpoint. """ - completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) - failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) - queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) + completed_states = _normalize_statuses( + COMPLETED_STATUSES if completed_statuses is None else completed_statuses + ) + failed_states = _normalize_statuses( + FAILED_STATUSES if failed_statuses is None else failed_statuses + ) + queued_states = _normalize_statuses( + QUEUED_STATUSES if queued_statuses is None else queued_statuses + ) started = time.monotonic() consumed_attempts = 0 # counts only non-queued polls @@ -289,7 +309,9 @@ async def poll_op_raw( break now = time.monotonic() proc_elapsed = state.base_processing_elapsed + ( - (now - state.active_since) if state.active_since is not None else 0.0 + (now - state.active_since) + if state.active_since is not None + else 0.0 ) _display_time_progress( cls, @@ -361,11 +383,15 @@ async def poll_op_raw( is_queued = status in queued_states if is_queued: - if state.active_since is not None: # If we just moved from active -> queued, close the active interval + if ( + state.active_since is not None + ): # If we just moved from active -> queued, close the active interval state.base_processing_elapsed += now_ts - state.active_since state.active_since = None else: - if state.active_since is None: # If we just moved from queued -> active, open a new active interval + if ( + state.active_since is None + ): # If we just moved from queued -> active, open a new active interval state.active_since = now_ts state.is_queued = is_queued @@ -442,7 +468,9 @@ def _display_text( ) -> None: display_lines: list[str] = [] if status: - display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") + display_lines.append( + f"Status: {status.capitalize() if isinstance(status, str) else status}" + ) if price is not None: p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": @@ -450,7 +478,9 @@ def _display_text( if text is not None: display_lines.append(text) if display_lines: - PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) + PromptServer.instance.send_progress_text( + "\n".join(display_lines), get_node_id(node_cls) + ) def _display_time_progress( @@ -464,7 +494,11 @@ def _display_time_progress( processing_elapsed_seconds: int | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: - pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + pe = ( + processing_elapsed_seconds + if processing_elapsed_seconds is not None + else elapsed_seconds + ) remaining = max(0, int(estimated_total) - int(pe)) time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" else: @@ -473,24 +507,48 @@ def _display_time_progress( async def _diagnose_connectivity() -> dict[str, bool]: - """Best-effort connectivity diagnostics to distinguish local vs. server issues.""" + """Best-effort connectivity diagnostics to distinguish local vs. server issues. + + Checks the Comfy API health endpoint first (the most relevant signal), + then falls back to multiple global probe URLs. The previous + implementation only checked ``google.com``, which is blocked behind + China's Great Firewall and caused **every** post-retry diagnostic for + Chinese users to misreport ``internet_accessible=False``. + """ results = { "internet_accessible": False, "api_accessible": False, } timeout = aiohttp.ClientTimeout(total=5.0) async with aiohttp.ClientSession(timeout=timeout) as session: - with contextlib.suppress(ClientError, OSError): - async with session.get("https://www.google.com") as resp: - results["internet_accessible"] = resp.status < 500 - if not results["internet_accessible"]: - return results - + # 1. Check the Comfy API health endpoint first — if it responds, + # both the internet and the API are reachable and we can return + # immediately without hitting any external probe. parsed = urlparse(default_base_url()) health_url = f"{parsed.scheme}://{parsed.netloc}/health" with contextlib.suppress(ClientError, OSError): async with session.get(health_url) as resp: results["api_accessible"] = resp.status < 500 + if results["api_accessible"]: + results["internet_accessible"] = True + return results + + # 2. API endpoint is down — determine whether the problem is + # local (no internet at all) or remote (API server issue). + # Probe several globally-reachable URLs so the check works in + # regions where specific sites are blocked (e.g. google.com in + # China). + _INTERNET_PROBE_URLS = [ + "https://www.google.com", + "https://www.baidu.com", + "https://captive.apple.com", + ] + for probe_url in _INTERNET_PROBE_URLS: + with contextlib.suppress(ClientError, OSError): + async with session.get(probe_url) as resp: + if resp.status < 500: + results["internet_accessible"] = True + break return results @@ -503,7 +561,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: raise ValueError("files tuple must be (filename, file[, content_type])") -def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: +def _merge_params( + endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None +) -> dict[str, Any]: params = dict(endpoint_params or {}) if method.upper() == "GET" and data: for k, v in data.items(): @@ -566,8 +626,14 @@ def _snapshot_request_body_for_logging( filename = file_obj[0] else: filename = getattr(file_obj, "name", field_name) - file_fields.append({"field": field_name, "filename": str(filename or "")}) - return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + file_fields.append( + {"field": field_name, "filename": str(filename or "")} + ) + return { + "_multipart": True, + "form_fields": form_fields, + "file_fields": file_fields, + } if content_type == "application/x-www-form-urlencoded": return data or {} return data or {} @@ -581,7 +647,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) method = cfg.endpoint.method - params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + params = _merge_params( + cfg.endpoint.query_params, method, cfg.data if method == "GET" else None + ) async def _monitor(stop_evt: asyncio.Event, start_ts: float): """Every second: update elapsed time and signal interruption.""" @@ -591,13 +659,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): return if cfg.monitor_progress: _display_time_progress( - cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + cfg.node_cls, + cfg.wait_label, + int(time.monotonic() - start_ts), + cfg.estimated_total, ) await asyncio.sleep(1.0) except asyncio.CancelledError: return # normal shutdown - start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + start_time = ( + cfg.progress_origin_ts + if cfg.progress_origin_ts is not None + else time.monotonic() + ) attempt = 0 delay = cfg.retry_delay rate_limit_attempts = 0 @@ -614,7 +689,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) - payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + payload_headers = ( + {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + ) if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? payload_headers.update(get_auth_header(cfg.node_cls)) if cfg.endpoint.headers: @@ -623,7 +700,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_kw: dict[str, Any] = {"headers": payload_headers} if method == "GET": payload_headers.pop("Content-Type", None) - request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + request_body_log = _snapshot_request_body_for_logging( + cfg.content_type, method, cfg.data, cfg.files + ) try: if cfg.monitor_progress: monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) @@ -637,16 +716,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if cfg.multipart_parser and cfg.data: form = cfg.multipart_parser(cfg.data) if not isinstance(form, aiohttp.FormData): - raise ValueError("multipart_parser must return aiohttp.FormData") + raise ValueError( + "multipart_parser must return aiohttp.FormData" + ) else: form = aiohttp.FormData(default_to_multipart=True) if cfg.data: for k, v in cfg.data.items(): if v is None: continue - form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + form.add_field( + k, + str(v) if not isinstance(v, (bytes, bytearray)) else v, + ) if cfg.files: - file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + file_iter = ( + cfg.files if isinstance(cfg.files, list) else cfg.files.items() + ) for field_name, file_obj in file_iter: if file_obj is None: continue @@ -660,9 +746,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if isinstance(file_value, BytesIO): with contextlib.suppress(Exception): file_value.seek(0) - form.add_field(field_name, file_value, filename=filename, content_type=content_type) + form.add_field( + field_name, + file_value, + filename=filename, + content_type=content_type, + ) payload_kw["data"] = form - elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + elif ( + cfg.content_type == "application/x-www-form-urlencoded" + and method != "GET" + ): payload_headers["Content-Type"] = "application/x-www-form-urlencoded" payload_kw["data"] = cfg.data or {} elif method != "GET": @@ -685,7 +779,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): tasks = {req_task} if monitor_task: tasks.add(monitor_task) - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) if monitor_task and monitor_task in done: # Interrupted – cancel the request and abort @@ -705,7 +801,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): wait_time = 0.0 retry_label = "" is_rl = resp.status == 429 or ( - cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body) + cfg.is_rate_limited is not None + and cfg.is_rate_limited(resp.status, body) ) if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit: rate_limit_attempts += 1 @@ -713,7 +810,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): rate_limit_delay *= cfg.retry_backoff retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}" should_retry = True - elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries: + elif ( + resp.status in _RETRY_STATUS + and (attempt - rate_limit_attempts) <= cfg.max_retries + ): wait_time = delay delay *= cfg.retry_backoff retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}" @@ -743,7 +843,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.wait_label if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None, cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, + display_callback=_display_time_progress + if cfg.monitor_progress + else None, ) continue msg = _friendly_http_message(resp.status, body) @@ -770,7 +872,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): raise ProcessingInterrupted("Task cancelled") if cfg.monitor_progress: _display_time_progress( - cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + cfg.node_cls, + cfg.wait_label, + int(now - start_time), + cfg.estimated_total, ) bytes_payload = bytes(buff) resp_headers = {k.lower(): v for k, v in resp.headers.items()} @@ -800,9 +905,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload = json.loads(text) if text else {} except json.JSONDecodeError: payload = {"_raw": text} - response_content_to_log = payload if isinstance(payload, dict) else text + response_content_to_log = ( + payload if isinstance(payload, dict) else text + ) with contextlib.suppress(Exception): - extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None + extracted_price = ( + cfg.price_extractor(payload) + if cfg.price_extractor + else None + ) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) request_logger.log_request_response( @@ -844,7 +955,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.wait_label if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None, cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, + display_callback=_display_time_progress + if cfg.monitor_progress + else None, ) delay *= cfg.retry_backoff continue @@ -885,7 +998,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if sess: with contextlib.suppress(Exception): await sess.close() - if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + if ( + operation_succeeded + and cfg.monitor_progress + and cfg.final_label_on_success + ): _display_time_progress( cfg.node_cls, status=cfg.final_label_on_success,