improve UX for batch uploads in upload_images_to_comfyapi (#10913)

This commit is contained in:
Alexander Piskun 2025-11-26 19:23:14 +02:00 committed by GitHub
parent 8938aa3f30
commit 1105e0d139
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,7 @@ import logging
import time import time
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Optional, Union from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
@ -48,8 +48,9 @@ async def upload_images_to_comfyapi(
image: torch.Tensor, image: torch.Tensor,
*, *,
max_images: int = 8, max_images: int = 8,
mime_type: Optional[str] = None, mime_type: str | None = None,
wait_label: Optional[str] = "Uploading", wait_label: str | None = "Uploading",
show_batch_index: bool = True,
) -> list[str]: ) -> list[str]:
""" """
Uploads images to ComfyUI API and returns download URLs. Uploads images to ComfyUI API and returns download URLs.
@ -59,11 +60,18 @@ async def upload_images_to_comfyapi(
download_urls: list[str] = [] download_urls: list[str] = []
is_batch = len(image.shape) > 3 is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1 batch_len = image.shape[0] if is_batch else 1
num_to_upload = min(batch_len, max_images)
batch_start_ts = time.monotonic()
for idx in range(min(batch_len, max_images)): for idx in range(num_to_upload):
tensor = image[idx] if is_batch else image tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type) img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1:
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
download_urls.append(url) download_urls.append(url)
return download_urls return download_urls
@ -126,8 +134,9 @@ async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
file_bytes_io: BytesIO, file_bytes_io: BytesIO,
filename: str, filename: str,
upload_mime_type: Optional[str], upload_mime_type: str | None,
wait_label: Optional[str] = "Uploading", wait_label: str | None = "Uploading",
progress_origin_ts: float | None = None,
) -> str: ) -> str:
"""Uploads a single file to ComfyUI API and returns its download URL.""" """Uploads a single file to ComfyUI API and returns its download URL."""
if upload_mime_type is None: if upload_mime_type is None:
@ -148,6 +157,7 @@ async def upload_file_to_comfyapi(
file_bytes_io, file_bytes_io,
content_type=upload_mime_type, content_type=upload_mime_type,
wait_label=wait_label, wait_label=wait_label,
progress_origin_ts=progress_origin_ts,
) )
return create_resp.download_url return create_resp.download_url
@ -155,27 +165,18 @@ async def upload_file_to_comfyapi(
async def upload_file( async def upload_file(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
upload_url: str, upload_url: str,
file: Union[BytesIO, str], file: BytesIO | str,
*, *,
content_type: Optional[str] = None, content_type: str | None = None,
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
wait_label: Optional[str] = None, wait_label: str | None = None,
progress_origin_ts: float | None = None,
) -> None: ) -> None:
""" """
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
Args:
cls: Node class (provides auth context + UI progress hooks).
upload_url: Pre-signed PUT URL.
file: BytesIO or path string.
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
max_retries: Maximum retry attempts.
retry_delay: Initial delay in seconds.
retry_backoff: Exponential backoff factor.
wait_label: Progress label shown in Comfy UI.
Raises: Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
""" """
@ -198,7 +199,7 @@ async def upload_file(
attempt = 0 attempt = 0
delay = retry_delay delay = retry_delay
start_ts = time.monotonic() start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
op_uuid = uuid.uuid4().hex[:8] op_uuid = uuid.uuid4().hex[:8]
while True: while True:
attempt += 1 attempt += 1