mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 01:12:37 +08:00
extend poll_op to use instead of custom async cycle
Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
parent
c08eae33cf
commit
f3701c56bf
@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
@ -202,55 +200,31 @@ async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str:
|
||||
ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"),
|
||||
response_model=SeedanceCreateVisualValidateSessionResponse,
|
||||
)
|
||||
|
||||
def _status_text(remaining_sec: int) -> str:
|
||||
return (
|
||||
"Seedance authentication required.\n"
|
||||
f"Open this link in your browser and complete face verification "
|
||||
f"(~{remaining_sec}s left):\n"
|
||||
f"{session.h5_link}"
|
||||
)
|
||||
|
||||
PromptServer.instance.send_progress_text(_status_text(_VERIFICATION_POLL_TIMEOUT_SEC), cls.hidden.unique_id)
|
||||
logger.warning("Seedance authentication required. Open link: %s", session.h5_link)
|
||||
|
||||
deadline = time.monotonic() + _VERIFICATION_POLL_TIMEOUT_SEC
|
||||
last_error: Exception | None = None
|
||||
while time.monotonic() < deadline:
|
||||
await asyncio.sleep(_VERIFICATION_POLL_INTERVAL_SEC)
|
||||
try:
|
||||
result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"),
|
||||
response_model=SeedanceGetVisualValidateSessionResponse,
|
||||
monitor_progress=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
continue
|
||||
h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}"
|
||||
|
||||
if result.status == "completed":
|
||||
if not result.group_id:
|
||||
raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id")
|
||||
logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id
|
||||
)
|
||||
return result.group_id
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"),
|
||||
response_model=SeedanceGetVisualValidateSessionResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed"],
|
||||
poll_interval=_VERIFICATION_POLL_INTERVAL_SEC,
|
||||
max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1,
|
||||
estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1,
|
||||
extra_text=h5_text,
|
||||
)
|
||||
|
||||
if result.status == "failed":
|
||||
parts = [f"Seedance authentication failed (session_id={session.session_id})."]
|
||||
if result.error_code:
|
||||
parts.append(f"code={result.error_code}")
|
||||
if result.error_message:
|
||||
parts.append(f"message={result.error_message}")
|
||||
raise RuntimeError(" ".join(parts))
|
||||
if not result.group_id:
|
||||
raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id")
|
||||
|
||||
remaining = max(0, int(deadline - time.monotonic()))
|
||||
PromptServer.instance.send_progress_text(_status_text(remaining), cls.hidden.unique_id)
|
||||
|
||||
hint = f" Last error: {last_error}" if last_error else ""
|
||||
raise RuntimeError(f"Seedance real-person authentication timed out after {_VERIFICATION_POLL_TIMEOUT_SEC}s.{hint}")
|
||||
logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id
|
||||
)
|
||||
return result.group_id
|
||||
|
||||
|
||||
async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str:
|
||||
|
||||
@ -156,6 +156,7 @@ async def poll_op(
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
cls,
|
||||
@ -176,6 +177,7 @@ async def poll_op(
|
||||
estimated_duration=estimated_duration,
|
||||
cancel_endpoint=cancel_endpoint,
|
||||
cancel_timeout=cancel_timeout,
|
||||
extra_text=extra_text,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@ -260,6 +262,7 @@ async def poll_op_raw(
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||
@ -299,6 +302,7 @@ async def poll_op_raw(
|
||||
price=state.price,
|
||||
is_queued=state.is_queued,
|
||||
processing_elapsed_seconds=int(proc_elapsed),
|
||||
extra_text=extra_text,
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception as exc:
|
||||
@ -389,6 +393,7 @@ async def poll_op_raw(
|
||||
price=state.price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=int(state.base_processing_elapsed),
|
||||
extra_text=extra_text,
|
||||
)
|
||||
return resp_json
|
||||
|
||||
@ -462,6 +467,7 @@ def _display_time_progress(
|
||||
price: float | None = None,
|
||||
is_queued: bool | None = None,
|
||||
processing_elapsed_seconds: int | None = None,
|
||||
extra_text: str | None = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
@ -469,7 +475,8 @@ def _display_time_progress(
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
|
||||
_display_text(node_cls, time_line, status=status, price=price)
|
||||
text = f"{time_line}\n\n{extra_text}" if extra_text else time_line
|
||||
_display_text(node_cls, text, status=status, price=price)
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user