extend poll_op to use instead of custom async cycle

Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
bigcat88 2026-04-22 14:59:23 +03:00
parent c08eae33cf
commit f3701c56bf
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
2 changed files with 28 additions and 47 deletions

View File

@ -1,8 +1,6 @@
import asyncio
import logging
import math
import re
import time
import torch
from typing_extensions import override
@ -202,55 +200,31 @@ async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str:
ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"),
response_model=SeedanceCreateVisualValidateSessionResponse,
)
def _status_text(remaining_sec: int) -> str:
return (
"Seedance authentication required.\n"
f"Open this link in your browser and complete face verification "
f"(~{remaining_sec}s left):\n"
f"{session.h5_link}"
)
PromptServer.instance.send_progress_text(_status_text(_VERIFICATION_POLL_TIMEOUT_SEC), cls.hidden.unique_id)
logger.warning("Seedance authentication required. Open link: %s", session.h5_link)
deadline = time.monotonic() + _VERIFICATION_POLL_TIMEOUT_SEC
last_error: Exception | None = None
while time.monotonic() < deadline:
await asyncio.sleep(_VERIFICATION_POLL_INTERVAL_SEC)
try:
result = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"),
response_model=SeedanceGetVisualValidateSessionResponse,
monitor_progress=False,
)
except Exception as exc:
last_error = exc
continue
h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}"
if result.status == "completed":
if not result.group_id:
raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id")
logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id)
PromptServer.instance.send_progress_text(
f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id
)
return result.group_id
result = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"),
response_model=SeedanceGetVisualValidateSessionResponse,
status_extractor=lambda r: r.status,
completed_statuses=["completed"],
failed_statuses=["failed"],
poll_interval=_VERIFICATION_POLL_INTERVAL_SEC,
max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1,
estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1,
extra_text=h5_text,
)
if result.status == "failed":
parts = [f"Seedance authentication failed (session_id={session.session_id})."]
if result.error_code:
parts.append(f"code={result.error_code}")
if result.error_message:
parts.append(f"message={result.error_message}")
raise RuntimeError(" ".join(parts))
if not result.group_id:
raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id")
remaining = max(0, int(deadline - time.monotonic()))
PromptServer.instance.send_progress_text(_status_text(remaining), cls.hidden.unique_id)
hint = f" Last error: {last_error}" if last_error else ""
raise RuntimeError(f"Seedance real-person authentication timed out after {_VERIFICATION_POLL_TIMEOUT_SEC}s.{hint}")
logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id)
PromptServer.instance.send_progress_text(
f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id
)
return result.group_id
async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str:

View File

@ -156,6 +156,7 @@ async def poll_op(
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
extra_text: str | None = None,
) -> M:
raw = await poll_op_raw(
cls,
@ -176,6 +177,7 @@ async def poll_op(
estimated_duration=estimated_duration,
cancel_endpoint=cancel_endpoint,
cancel_timeout=cancel_timeout,
extra_text=extra_text,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@ -260,6 +262,7 @@ async def poll_op_raw(
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
extra_text: str | None = None,
) -> dict[str, Any]:
"""
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
@ -299,6 +302,7 @@ async def poll_op_raw(
price=state.price,
is_queued=state.is_queued,
processing_elapsed_seconds=int(proc_elapsed),
extra_text=extra_text,
)
await asyncio.sleep(1.0)
except Exception as exc:
@ -389,6 +393,7 @@ async def poll_op_raw(
price=state.price,
is_queued=False,
processing_elapsed_seconds=int(state.base_processing_elapsed),
extra_text=extra_text,
)
return resp_json
@ -462,6 +467,7 @@ def _display_time_progress(
price: float | None = None,
is_queued: bool | None = None,
processing_elapsed_seconds: int | None = None,
extra_text: str | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@ -469,7 +475,8 @@ def _display_time_progress(
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
_display_text(node_cls, time_line, status=status, price=price)
text = f"{time_line}\n\n{extra_text}" if extra_text else time_line
_display_text(node_cls, text, status=status, price=price)
async def _diagnose_connectivity() -> dict[str, bool]: