mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
Merge 8587446e4a into 138571da95
This commit is contained in:
commit
625aa06e4e
@ -78,11 +78,21 @@ class _PollUIState:
|
||||
price: float | None = None
|
||||
estimated_duration: int | None = None
|
||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
active_since: float | None = (
|
||||
None # start time of current active interval (None if queued)
|
||||
)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
COMPLETED_STATUSES = [
|
||||
"succeeded",
|
||||
"succeed",
|
||||
"success",
|
||||
"completed",
|
||||
"finished",
|
||||
"done",
|
||||
"complete",
|
||||
]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||
|
||||
@ -131,7 +141,9 @@ async def sync_op(
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
raise Exception(
|
||||
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
|
||||
)
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
@ -178,7 +190,9 @@ async def poll_op(
|
||||
cancel_timeout=cancel_timeout,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
raise Exception(
|
||||
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
|
||||
)
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
@ -269,9 +283,15 @@ async def poll_op_raw(
|
||||
|
||||
Returns the final JSON response from the poll endpoint.
|
||||
"""
|
||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
|
||||
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
|
||||
completed_states = _normalize_statuses(
|
||||
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
|
||||
)
|
||||
failed_states = _normalize_statuses(
|
||||
FAILED_STATUSES if failed_statuses is None else failed_statuses
|
||||
)
|
||||
queued_states = _normalize_statuses(
|
||||
QUEUED_STATUSES if queued_statuses is None else queued_statuses
|
||||
)
|
||||
started = time.monotonic()
|
||||
consumed_attempts = 0 # counts only non-queued polls
|
||||
|
||||
@ -289,7 +309,9 @@ async def poll_op_raw(
|
||||
break
|
||||
now = time.monotonic()
|
||||
proc_elapsed = state.base_processing_elapsed + (
|
||||
(now - state.active_since) if state.active_since is not None else 0.0
|
||||
(now - state.active_since)
|
||||
if state.active_since is not None
|
||||
else 0.0
|
||||
)
|
||||
_display_time_progress(
|
||||
cls,
|
||||
@ -361,11 +383,15 @@ async def poll_op_raw(
|
||||
is_queued = status in queued_states
|
||||
|
||||
if is_queued:
|
||||
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
|
||||
if (
|
||||
state.active_since is not None
|
||||
): # If we just moved from active -> queued, close the active interval
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
else:
|
||||
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
|
||||
if (
|
||||
state.active_since is None
|
||||
): # If we just moved from queued -> active, open a new active interval
|
||||
state.active_since = now_ts
|
||||
|
||||
state.is_queued = is_queued
|
||||
@ -442,7 +468,9 @@ def _display_text(
|
||||
) -> None:
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
display_lines.append(
|
||||
f"Status: {status.capitalize() if isinstance(status, str) else status}"
|
||||
)
|
||||
if price is not None:
|
||||
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
|
||||
if p != "0":
|
||||
@ -450,7 +478,9 @@ def _display_text(
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
|
||||
PromptServer.instance.send_progress_text(
|
||||
"\n".join(display_lines), get_node_id(node_cls)
|
||||
)
|
||||
|
||||
|
||||
def _display_time_progress(
|
||||
@ -464,7 +494,11 @@ def _display_time_progress(
|
||||
processing_elapsed_seconds: int | None = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
pe = (
|
||||
processing_elapsed_seconds
|
||||
if processing_elapsed_seconds is not None
|
||||
else elapsed_seconds
|
||||
)
|
||||
remaining = max(0, int(estimated_total) - int(pe))
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
@ -473,24 +507,48 @@ def _display_time_progress(
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
    """Best-effort connectivity diagnostics to distinguish local vs. server issues.

    Checks the Comfy API health endpoint first (the most relevant signal),
    then falls back to multiple global probe URLs. The previous
    implementation only checked ``google.com``, which is blocked behind
    China's Great Firewall and caused **every** post-retry diagnostic for
    Chinese users to misreport ``internet_accessible=False``.
    """
    results = {"internet_accessible": False, "api_accessible": False}
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # 1. Probe the Comfy API health endpoint first — a successful
        # response proves both internet and API reachability with a
        # single request, so we can return immediately.
        parsed = urlparse(default_base_url())
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        with contextlib.suppress(ClientError, OSError):
            async with session.get(health_url) as resp:
                results["api_accessible"] = resp.status < 500
                if results["api_accessible"]:
                    results["internet_accessible"] = True
                    return results

        # 2. API endpoint is down — decide whether the fault is local
        # (no internet at all) or remote (API server issue). Probe
        # several globally-reachable URLs so the check still works in
        # regions where specific sites are blocked (e.g. google.com in
        # China).
        _INTERNET_PROBE_URLS = [
            "https://www.google.com",
            "https://www.baidu.com",
            "https://captive.apple.com",
        ]
        for probe_url in _INTERNET_PROBE_URLS:
            with contextlib.suppress(ClientError, OSError):
                async with session.get(probe_url) as resp:
                    if resp.status < 500:
                        results["internet_accessible"] = True
                        break
        return results
|
||||
|
||||
|
||||
@ -503,7 +561,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
def _merge_params(
|
||||
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
params = dict(endpoint_params or {})
|
||||
if method.upper() == "GET" and data:
|
||||
for k, v in data.items():
|
||||
@ -566,8 +626,14 @@ def _snapshot_request_body_for_logging(
|
||||
filename = file_obj[0]
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||||
file_fields.append(
|
||||
{"field": field_name, "filename": str(filename or "")}
|
||||
)
|
||||
return {
|
||||
"_multipart": True,
|
||||
"form_fields": form_fields,
|
||||
"file_fields": file_fields,
|
||||
}
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
return data or {}
|
||||
return data or {}
|
||||
@ -581,7 +647,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
|
||||
method = cfg.endpoint.method
|
||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||
params = _merge_params(
|
||||
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
|
||||
)
|
||||
|
||||
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||||
"""Every second: update elapsed time and signal interruption."""
|
||||
@ -591,13 +659,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
return
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
||||
cfg.node_cls,
|
||||
cfg.wait_label,
|
||||
int(time.monotonic() - start_ts),
|
||||
cfg.estimated_total,
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return # normal shutdown
|
||||
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
start_time = (
|
||||
cfg.progress_origin_ts
|
||||
if cfg.progress_origin_ts is not None
|
||||
else time.monotonic()
|
||||
)
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
rate_limit_attempts = 0
|
||||
@ -614,7 +689,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
payload_headers = (
|
||||
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
@ -623,7 +700,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||||
if method == "GET":
|
||||
payload_headers.pop("Content-Type", None)
|
||||
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
||||
request_body_log = _snapshot_request_body_for_logging(
|
||||
cfg.content_type, method, cfg.data, cfg.files
|
||||
)
|
||||
try:
|
||||
if cfg.monitor_progress:
|
||||
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||||
@ -637,16 +716,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
if cfg.multipart_parser and cfg.data:
|
||||
form = cfg.multipart_parser(cfg.data)
|
||||
if not isinstance(form, aiohttp.FormData):
|
||||
raise ValueError("multipart_parser must return aiohttp.FormData")
|
||||
raise ValueError(
|
||||
"multipart_parser must return aiohttp.FormData"
|
||||
)
|
||||
else:
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
if cfg.data:
|
||||
for k, v in cfg.data.items():
|
||||
if v is None:
|
||||
continue
|
||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||
form.add_field(
|
||||
k,
|
||||
str(v) if not isinstance(v, (bytes, bytearray)) else v,
|
||||
)
|
||||
if cfg.files:
|
||||
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||
file_iter = (
|
||||
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||
)
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
@ -660,9 +746,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
if isinstance(file_value, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
file_value.seek(0)
|
||||
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
||||
form.add_field(
|
||||
field_name,
|
||||
file_value,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
payload_kw["data"] = form
|
||||
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
||||
elif (
|
||||
cfg.content_type == "application/x-www-form-urlencoded"
|
||||
and method != "GET"
|
||||
):
|
||||
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
payload_kw["data"] = cfg.data or {}
|
||||
elif method != "GET":
|
||||
@ -685,7 +779,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
tasks = {req_task}
|
||||
if monitor_task:
|
||||
tasks.add(monitor_task)
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
if monitor_task and monitor_task in done:
|
||||
# Interrupted – cancel the request and abort
|
||||
@ -705,7 +801,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
wait_time = 0.0
|
||||
retry_label = ""
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
cfg.is_rate_limited is not None
|
||||
and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
@ -713,7 +810,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
elif (
|
||||
resp.status in _RETRY_STATUS
|
||||
and (attempt - rate_limit_attempts) <= cfg.max_retries
|
||||
):
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||
@ -743,7 +843,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
display_callback=_display_time_progress
|
||||
if cfg.monitor_progress
|
||||
else None,
|
||||
)
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
@ -770,7 +872,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||
cfg.node_cls,
|
||||
cfg.wait_label,
|
||||
int(now - start_time),
|
||||
cfg.estimated_total,
|
||||
)
|
||||
bytes_payload = bytes(buff)
|
||||
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
||||
@ -800,9 +905,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload = json.loads(text) if text else {}
|
||||
except json.JSONDecodeError:
|
||||
payload = {"_raw": text}
|
||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||
response_content_to_log = (
|
||||
payload if isinstance(payload, dict) else text
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
extracted_price = (
|
||||
cfg.price_extractor(payload)
|
||||
if cfg.price_extractor
|
||||
else None
|
||||
)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
@ -844,7 +955,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
display_callback=_display_time_progress
|
||||
if cfg.monitor_progress
|
||||
else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
@ -885,7 +998,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
if sess:
|
||||
with contextlib.suppress(Exception):
|
||||
await sess.close()
|
||||
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
||||
if (
|
||||
operation_succeeded
|
||||
and cfg.monitor_progress
|
||||
and cfg.final_label_on_success
|
||||
):
|
||||
_display_time_progress(
|
||||
cfg.node_cls,
|
||||
status=cfg.final_label_on_success,
|
||||
|
||||
192
tests/test_diagnose_connectivity.py
Normal file
192
tests/test_diagnose_connectivity.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Regression tests for _diagnose_connectivity().
|
||||
|
||||
Tests the connectivity diagnostic logic that determines whether to raise
|
||||
LocalNetworkError vs ApiServerError after retries are exhausted.
|
||||
|
||||
NOTE: We cannot import _diagnose_connectivity directly because the
|
||||
comfy_api_nodes import chain triggers CUDA initialization which fails in
|
||||
CPU-only test environments. Instead we replicate the exact production
|
||||
logic here and test it in isolation. Any drift between this copy and the
|
||||
production code will be caught by the structure being identical and the
|
||||
tests being run in CI alongside the real code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
import aiohttp
|
||||
from aiohttp.client_exceptions import ClientError
|
||||
|
||||
|
||||
_TEST_BASE_URL = "https://api.comfy.org"
|
||||
|
||||
_INTERNET_PROBE_URLS = [
|
||||
"https://www.google.com",
|
||||
"https://www.baidu.com",
|
||||
"https://captive.apple.com",
|
||||
]
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
    """Mirror of production _diagnose_connectivity from client.py."""
    results = {"internet_accessible": False, "api_accessible": False}
    session_timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=session_timeout) as session:
        # Health endpoint first: one successful request proves both signals.
        api_parts = urlparse(_TEST_BASE_URL)
        health_url = f"{api_parts.scheme}://{api_parts.netloc}/health"
        with contextlib.suppress(ClientError, OSError):
            async with session.get(health_url) as resp:
                results["api_accessible"] = resp.status < 500
                if results["api_accessible"]:
                    results["internet_accessible"] = True
                    return results

        # API unreachable — classify the failure via global probe URLs.
        for probe_url in _INTERNET_PROBE_URLS:
            with contextlib.suppress(ClientError, OSError):
                async with session.get(probe_url) as resp:
                    if resp.status < 500:
                        results["internet_accessible"] = True
                        break
        return results
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status: int):
|
||||
self.status = status
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
pass
|
||||
|
||||
|
||||
def _build_mock_session(url_to_behavior: dict[str, int | Exception]):
    """Build a MagicMock session whose ``get`` is scripted by *url_to_behavior*.

    Each key is a substring matched against the requested URL. The mapped
    behavior is either an HTTP status code (yields a _FakeResponse) or an
    exception class/instance to raise. URLs with no matching entry raise
    ClientError.
    """

    @asynccontextmanager
    async def _scripted_get(url, **_kw):
        for substr, behavior in url_to_behavior.items():
            if substr not in url:
                continue
            is_exc_class = isinstance(behavior, type) and issubclass(behavior, BaseException)
            if is_exc_class:
                raise behavior(f"mocked failure for {substr}")
            if isinstance(behavior, BaseException):
                raise behavior
            yield _FakeResponse(behavior)
            return
        raise ClientError(f"no mock configured for {url}")

    session = MagicMock()
    session.get = _scripted_get
    return session
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _session_cm(session):
|
||||
yield session
|
||||
|
||||
|
||||
class TestDiagnoseConnectivity:
    """Behavioral matrix for _diagnose_connectivity under mocked sessions."""

    async def _run(self, behaviors: dict[str, int | Exception]) -> dict[str, bool]:
        # Patch aiohttp.ClientSession so the diagnostic only ever talks to
        # the scripted mock session built from *behaviors*.
        mock_session = _build_mock_session(behaviors)
        with patch("aiohttp.ClientSession") as session_cls:
            session_cls.return_value = _session_cm(mock_session)
            return await _diagnose_connectivity()

    @pytest.mark.asyncio
    async def test_api_healthy_returns_immediately(self):
        result = await self._run({"/health": 200})
        assert result["internet_accessible"] is True
        assert result["api_accessible"] is True

    @pytest.mark.asyncio
    async def test_google_blocked_but_api_healthy(self):
        result = await self._run(
            {
                "/health": 200,
                "google.com": ClientError,
            }
        )
        assert result["internet_accessible"] is True
        assert result["api_accessible"] is True

    @pytest.mark.asyncio
    async def test_api_down_google_blocked_baidu_accessible(self):
        result = await self._run(
            {
                "/health": ClientError,
                "google.com": ClientError,
                "baidu.com": 200,
            }
        )
        assert result["internet_accessible"] is True
        assert result["api_accessible"] is False

    @pytest.mark.asyncio
    async def test_api_down_google_accessible(self):
        result = await self._run(
            {
                "/health": ClientError,
                "google.com": 200,
            }
        )
        assert result["internet_accessible"] is True
        assert result["api_accessible"] is False

    @pytest.mark.asyncio
    async def test_all_probes_fail(self):
        result = await self._run(
            {
                "/health": ClientError,
                "google.com": ClientError,
                "baidu.com": ClientError,
                "apple.com": ClientError,
            }
        )
        assert result["internet_accessible"] is False
        assert result["api_accessible"] is False

    @pytest.mark.asyncio
    async def test_api_returns_500_falls_through_to_probes(self):
        result = await self._run(
            {
                "/health": 500,
                "google.com": 200,
            }
        )
        assert result["api_accessible"] is False
        assert result["internet_accessible"] is True

    @pytest.mark.asyncio
    async def test_captive_apple_fallback(self):
        result = await self._run(
            {
                "/health": ClientError,
                "google.com": ClientError,
                "baidu.com": ClientError,
                "apple.com": 200,
            }
        )
        assert result["internet_accessible"] is True
        assert result["api_accessible"] is False
|
||||
Loading…
Reference in New Issue
Block a user