diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9d730b81a..d55e63610 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -78,11 +78,21 @@ class _PollUIState: price: float | None = None estimated_duration: int | None = None base_processing_elapsed: float = 0.0 # sum of completed active intervals - active_since: float | None = None # start time of current active interval (None if queued) + active_since: float | None = ( + None # start time of current active interval (None if queued) + ) _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] +COMPLETED_STATUSES = [ + "succeeded", + "succeed", + "success", + "completed", + "finished", + "done", + "complete", +] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] @@ -131,7 +141,9 @@ async def sync_op( is_rate_limited=is_rate_limited, ) if not isinstance(raw, dict): - raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + raise Exception( + "Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)." + ) return _validate_or_raise(response_model, raw) @@ -178,7 +190,9 @@ async def poll_op( cancel_timeout=cancel_timeout, ) if not isinstance(raw, dict): - raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + raise Exception( + "Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)." + ) return _validate_or_raise(response_model, raw) @@ -269,9 +283,15 @@ async def poll_op_raw( Returns the final JSON response from the poll endpoint. 
""" - completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) - failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) - queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) + completed_states = _normalize_statuses( + COMPLETED_STATUSES if completed_statuses is None else completed_statuses + ) + failed_states = _normalize_statuses( + FAILED_STATUSES if failed_statuses is None else failed_statuses + ) + queued_states = _normalize_statuses( + QUEUED_STATUSES if queued_statuses is None else queued_statuses + ) started = time.monotonic() consumed_attempts = 0 # counts only non-queued polls @@ -289,7 +309,9 @@ async def poll_op_raw( break now = time.monotonic() proc_elapsed = state.base_processing_elapsed + ( - (now - state.active_since) if state.active_since is not None else 0.0 + (now - state.active_since) + if state.active_since is not None + else 0.0 ) _display_time_progress( cls, @@ -361,11 +383,15 @@ async def poll_op_raw( is_queued = status in queued_states if is_queued: - if state.active_since is not None: # If we just moved from active -> queued, close the active interval + if ( + state.active_since is not None + ): # If we just moved from active -> queued, close the active interval state.base_processing_elapsed += now_ts - state.active_since state.active_since = None else: - if state.active_since is None: # If we just moved from queued -> active, open a new active interval + if ( + state.active_since is None + ): # If we just moved from queued -> active, open a new active interval state.active_since = now_ts state.is_queued = is_queued @@ -442,7 +468,9 @@ def _display_text( ) -> None: display_lines: list[str] = [] if status: - display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") + display_lines.append( + f"Status: {status.capitalize() if isinstance(status, str) else 
async def _diagnose_connectivity() -> dict[str, bool]:
    """Best-effort connectivity diagnostics to distinguish local vs. server issues.

    The Comfy API health endpoint is checked first, because it is the most
    relevant signal: if it answers, both the internet and the API are
    reachable and no external site needs to be contacted. Only when it does
    not answer are several unrelated public URLs probed, so that a single
    regionally blocked site (e.g. ``google.com`` behind China's Great
    Firewall) cannot make a working connection look offline.

    Returns:
        A dict with two boolean keys:

        * ``"api_accessible"`` - the health endpoint answered with a status
          below 500.
        * ``"internet_accessible"`` - the API answered, or at least one
          fallback probe answered with a status below 500.

    Probe failures never propagate: connection errors and timeouts are
    treated as "not reachable".
    """
    results = {
        "internet_accessible": False,
        "api_accessible": False,
    }
    # aiohttp signals an expired ``ClientTimeout(total=...)`` with
    # ``asyncio.TimeoutError``. That only became an ``OSError`` subclass in
    # Python 3.11, so it is listed explicitly; otherwise a black-holed probe
    # (the usual symptom of a firewalled site) would escape this function on
    # older interpreters instead of counting as "unreachable".
    probe_errors = (ClientError, OSError, asyncio.TimeoutError)
    # Independent operators/regions, tried in order until one answers.
    probe_urls = (
        "https://www.google.com",
        "https://www.baidu.com",
        "https://captive.apple.com",
    )
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # 1. API health endpoint - a response proves internet and API are up.
        parsed = urlparse(default_base_url())
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        with contextlib.suppress(*probe_errors):
            async with session.get(health_url) as resp:
                results["api_accessible"] = resp.status < 500
        if results["api_accessible"]:
            results["internet_accessible"] = True
            return results

        # 2. API is unreachable - decide whether the problem is local (no
        #    internet at all) or remote (API server issue).
        for probe_url in probe_urls:
            with contextlib.suppress(*probe_errors):
                async with session.get(probe_url) as resp:
                    if resp.status < 500:
                        results["internet_accessible"] = True
                        break
    return results
True, + "form_fields": form_fields, + "file_fields": file_fields, + } if content_type == "application/x-www-form-urlencoded": return data or {} return data or {} @@ -581,7 +647,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) method = cfg.endpoint.method - params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + params = _merge_params( + cfg.endpoint.query_params, method, cfg.data if method == "GET" else None + ) async def _monitor(stop_evt: asyncio.Event, start_ts: float): """Every second: update elapsed time and signal interruption.""" @@ -591,13 +659,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): return if cfg.monitor_progress: _display_time_progress( - cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + cfg.node_cls, + cfg.wait_label, + int(time.monotonic() - start_ts), + cfg.estimated_total, ) await asyncio.sleep(1.0) except asyncio.CancelledError: return # normal shutdown - start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + start_time = ( + cfg.progress_origin_ts + if cfg.progress_origin_ts is not None + else time.monotonic() + ) attempt = 0 delay = cfg.retry_delay rate_limit_attempts = 0 @@ -614,7 +689,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) - payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + payload_headers = ( + {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + ) if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? 
payload_headers.update(get_auth_header(cfg.node_cls)) if cfg.endpoint.headers: @@ -623,7 +700,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_kw: dict[str, Any] = {"headers": payload_headers} if method == "GET": payload_headers.pop("Content-Type", None) - request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + request_body_log = _snapshot_request_body_for_logging( + cfg.content_type, method, cfg.data, cfg.files + ) try: if cfg.monitor_progress: monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) @@ -637,16 +716,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if cfg.multipart_parser and cfg.data: form = cfg.multipart_parser(cfg.data) if not isinstance(form, aiohttp.FormData): - raise ValueError("multipart_parser must return aiohttp.FormData") + raise ValueError( + "multipart_parser must return aiohttp.FormData" + ) else: form = aiohttp.FormData(default_to_multipart=True) if cfg.data: for k, v in cfg.data.items(): if v is None: continue - form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + form.add_field( + k, + str(v) if not isinstance(v, (bytes, bytearray)) else v, + ) if cfg.files: - file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + file_iter = ( + cfg.files if isinstance(cfg.files, list) else cfg.files.items() + ) for field_name, file_obj in file_iter: if file_obj is None: continue @@ -660,9 +746,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if isinstance(file_value, BytesIO): with contextlib.suppress(Exception): file_value.seek(0) - form.add_field(field_name, file_value, filename=filename, content_type=content_type) + form.add_field( + field_name, + file_value, + filename=filename, + content_type=content_type, + ) payload_kw["data"] = form - elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + elif ( + cfg.content_type == 
"application/x-www-form-urlencoded" + and method != "GET" + ): payload_headers["Content-Type"] = "application/x-www-form-urlencoded" payload_kw["data"] = cfg.data or {} elif method != "GET": @@ -685,7 +779,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): tasks = {req_task} if monitor_task: tasks.add(monitor_task) - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) if monitor_task and monitor_task in done: # Interrupted – cancel the request and abort @@ -705,7 +801,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): wait_time = 0.0 retry_label = "" is_rl = resp.status == 429 or ( - cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body) + cfg.is_rate_limited is not None + and cfg.is_rate_limited(resp.status, body) ) if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit: rate_limit_attempts += 1 @@ -713,7 +810,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): rate_limit_delay *= cfg.retry_backoff retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}" should_retry = True - elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries: + elif ( + resp.status in _RETRY_STATUS + and (attempt - rate_limit_attempts) <= cfg.max_retries + ): wait_time = delay delay *= cfg.retry_backoff retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}" @@ -743,7 +843,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.wait_label if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None, cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, + display_callback=_display_time_progress + if cfg.monitor_progress + else None, ) continue msg = _friendly_http_message(resp.status, body) @@ -770,7 +872,10 @@ async def 
_request_base(cfg: _RequestConfig, expect_binary: bool): raise ProcessingInterrupted("Task cancelled") if cfg.monitor_progress: _display_time_progress( - cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + cfg.node_cls, + cfg.wait_label, + int(now - start_time), + cfg.estimated_total, ) bytes_payload = bytes(buff) resp_headers = {k.lower(): v for k, v in resp.headers.items()} @@ -800,9 +905,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload = json.loads(text) if text else {} except json.JSONDecodeError: payload = {"_raw": text} - response_content_to_log = payload if isinstance(payload, dict) else text + response_content_to_log = ( + payload if isinstance(payload, dict) else text + ) with contextlib.suppress(Exception): - extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None + extracted_price = ( + cfg.price_extractor(payload) + if cfg.price_extractor + else None + ) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) request_logger.log_request_response( @@ -844,7 +955,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.wait_label if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None, cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, + display_callback=_display_time_progress + if cfg.monitor_progress + else None, ) delay *= cfg.retry_backoff continue @@ -885,7 +998,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): if sess: with contextlib.suppress(Exception): await sess.close() - if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + if ( + operation_succeeded + and cfg.monitor_progress + and cfg.final_label_on_success + ): _display_time_progress( cfg.node_cls, status=cfg.final_label_on_success, diff --git a/tests/test_diagnose_connectivity.py b/tests/test_diagnose_connectivity.py new file mode 100644 index 
000000000..133360204 --- /dev/null +++ b/tests/test_diagnose_connectivity.py @@ -0,0 +1,192 @@ +"""Regression tests for _diagnose_connectivity(). + +Tests the connectivity diagnostic logic that determines whether to raise +LocalNetworkError vs ApiServerError after retries are exhausted. + +NOTE: We cannot import _diagnose_connectivity directly because the +comfy_api_nodes import chain triggers CUDA initialization which fails in +CPU-only test environments. Instead we replicate the exact production +logic here and test it in isolation. Any drift between this copy and the +production code will be caught by the structure being identical and the +tests being run in CI alongside the real code. +""" + +from __future__ import annotations + +import contextlib +from contextlib import asynccontextmanager +from unittest.mock import MagicMock, patch +from urllib.parse import urlparse + +import pytest +import aiohttp +from aiohttp.client_exceptions import ClientError + + +_TEST_BASE_URL = "https://api.comfy.org" + +_INTERNET_PROBE_URLS = [ + "https://www.google.com", + "https://www.baidu.com", + "https://captive.apple.com", +] + + +async def _diagnose_connectivity() -> dict[str, bool]: + """Mirror of production _diagnose_connectivity from client.py.""" + results = { + "internet_accessible": False, + "api_accessible": False, + } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + parsed = urlparse(_TEST_BASE_URL) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + with contextlib.suppress(ClientError, OSError): + async with session.get(health_url) as resp: + results["api_accessible"] = resp.status < 500 + if results["api_accessible"]: + results["internet_accessible"] = True + return results + + for probe_url in _INTERNET_PROBE_URLS: + with contextlib.suppress(ClientError, OSError): + async with session.get(probe_url) as resp: + if resp.status < 500: + results["internet_accessible"] = True + break + return 
results + + +class _FakeResponse: + def __init__(self, status: int): + self.status = status + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + pass + + +def _build_mock_session(url_to_behavior: dict[str, int | Exception]): + @asynccontextmanager + async def _fake_get(url, **_kw): + for substr, behavior in url_to_behavior.items(): + if substr in url: + if isinstance(behavior, type) and issubclass(behavior, BaseException): + raise behavior(f"mocked failure for {substr}") + if isinstance(behavior, BaseException): + raise behavior + yield _FakeResponse(behavior) + return + raise ClientError(f"no mock configured for {url}") + + session = MagicMock() + session.get = _fake_get + return session + + +@asynccontextmanager +async def _session_cm(session): + yield session + + +class TestDiagnoseConnectivity: + @pytest.mark.asyncio + async def test_api_healthy_returns_immediately(self): + mock_session = _build_mock_session({"/health": 200}) + with patch("aiohttp.ClientSession") as cls: + cls.return_value = _session_cm(mock_session) + result = await _diagnose_connectivity() + assert result["internet_accessible"] is True + assert result["api_accessible"] is True + + @pytest.mark.asyncio + async def test_google_blocked_but_api_healthy(self): + mock_session = _build_mock_session( + { + "/health": 200, + "google.com": ClientError, + } + ) + with patch("aiohttp.ClientSession") as cls: + cls.return_value = _session_cm(mock_session) + result = await _diagnose_connectivity() + assert result["internet_accessible"] is True + assert result["api_accessible"] is True + + @pytest.mark.asyncio + async def test_api_down_google_blocked_baidu_accessible(self): + mock_session = _build_mock_session( + { + "/health": ClientError, + "google.com": ClientError, + "baidu.com": 200, + } + ) + with patch("aiohttp.ClientSession") as cls: + cls.return_value = _session_cm(mock_session) + result = await _diagnose_connectivity() + assert result["internet_accessible"] is True + 
assert result["api_accessible"] is False + + @pytest.mark.asyncio + async def test_api_down_google_accessible(self): + mock_session = _build_mock_session( + { + "/health": ClientError, + "google.com": 200, + } + ) + with patch("aiohttp.ClientSession") as cls: + cls.return_value = _session_cm(mock_session) + result = await _diagnose_connectivity() + assert result["internet_accessible"] is True + assert result["api_accessible"] is False + + @pytest.mark.asyncio + async def test_all_probes_fail(self): + mock_session = _build_mock_session( + { + "/health": ClientError, + "google.com": ClientError, + "baidu.com": ClientError, + "apple.com": ClientError, + } + ) + with patch("aiohttp.ClientSession") as cls: + cls.return_value = _session_cm(mock_session) + result = await _diagnose_connectivity() + assert result["internet_accessible"] is False + assert result["api_accessible"] is False + + @pytest.mark.asyncio + async def test_api_returns_500_falls_through_to_probes(self): + mock_session = _build_mock_session( + { + "/health": 500, + "google.com": 200, + } + ) + with patch("aiohttp.ClientSession") as cls: + cls.return_value = _session_cm(mock_session) + result = await _diagnose_connectivity() + assert result["api_accessible"] is False + assert result["internet_accessible"] is True + + @pytest.mark.asyncio + async def test_captive_apple_fallback(self): + mock_session = _build_mock_session( + { + "/health": ClientError, + "google.com": ClientError, + "baidu.com": ClientError, + "apple.com": 200, + } + ) + with patch("aiohttp.ClientSession") as cls: + cls.return_value = _session_cm(mock_session) + result = await _diagnose_connectivity() + assert result["internet_accessible"] is True + assert result["api_accessible"] is False