fix(api-nodes-vidu): preserve percent-encoding for signed URLs

This commit is contained in:
bigcat88 2025-12-30 15:26:33 +02:00
parent d7111e426a
commit 36f7b0cca7
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
2 changed files with 22 additions and 1 deletions

View File

@ -1,16 +1,22 @@
import asyncio import asyncio
import contextlib import contextlib
import os import os
import re
import time import time
from collections.abc import Callable from collections.abc import Callable
from io import BytesIO from io import BytesIO
from yarl import URL
from comfy.cli_args import args from comfy.cli_args import args
from comfy.model_management import processing_interrupted from comfy.model_management import processing_interrupted
from comfy_api.latest import IO from comfy_api.latest import IO
from .common_exceptions import ProcessingInterrupted from .common_exceptions import ProcessingInterrupted
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
def is_processing_interrupted() -> bool: def is_processing_interrupted() -> bool:
"""Return True if user/runtime requested interruption.""" """Return True if user/runtime requested interruption."""
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str): if isinstance(path_or_object, str):
return os.path.getsize(path_or_object) return os.path.getsize(path_or_object)
return len(path_or_object.getvalue()) return len(path_or_object.getvalue())
def to_aiohttp_url(url: str) -> URL:
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
escape and no malformed '%' sequences) and contains no raw whitespace/control
characters preserve the original encoding byte-for-byte (important for signed/presigned URLs).
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
# Avoid encoded=True if URL contains raw whitespace/control chars
return URL(url)
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
return URL(url, encoded=True)
return URL(url)

View File

@ -19,6 +19,7 @@ from ._helpers import (
get_auth_header, get_auth_header,
is_processing_interrupted, is_processing_interrupted,
sleep_with_interrupt, sleep_with_interrupt,
to_aiohttp_url,
) )
from .client import _diagnose_connectivity from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
@ -94,7 +95,7 @@ async def download_url_to_bytesio(
monitor_task = asyncio.create_task(_monitor()) monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(url, headers=headers)) req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending: if monitor_task in done and req_task in pending: