diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index 628916abb..32501dfb2 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -6,12 +6,15 @@ import logging import os import posixpath import re +import ssl import sys import uuid from datetime import datetime from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple import PIL +import aiohttp +import certifi import fsspec import numpy as np import torch @@ -54,6 +57,17 @@ def is_null_uri(local_uri): return local_uri == _null_uri or local_uri == "NUL" +async def get_client(**kwargs): + """ + workaround for issues with fsspec on Windows + :param kwargs: + :return: + """ + ssl_context = ssl.create_default_context(cafile=certifi.where()) + conn = aiohttp.TCPConnector(ssl=ssl_context) + return aiohttp.ClientSession(connector=conn, **kwargs) + + class FsSpecComfyMetadata(TypedDict, total=True): prompt_json_str: str batch_number_str: str @@ -587,8 +601,14 @@ class SaveImagesResponse(CustomNode): if not uri_is_remote and not os.path.isabs(uri): uri = os.path.join(output_directory, uri) abs_path = uri + fsspec_kwargs = {} + if not uri_is_remote: + fsspec_kwargs["auto_mkdir"] = True + # todo: this might need special handling for s3 URLs too + if uri.startswith("http"): + fsspec_kwargs['get_client'] = get_client try: - with fsspec.open(uri, mode="wb", auto_mkdir=True) as f: + with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f: image_as_pil.save(f, format=pil_save_format, **additional_args) if metadata_uri is not None: # all values are stringified for the metadata @@ -646,18 +666,18 @@ class ImageRequestParameter(CustomNode): def execute(self, value: str = "", *args, **kwargs) -> ValidatedNodeResult: output_images = [] - f: OpenFile - kwargs_for_fsspec = {} + fsspec_kwargs = {} if value.startswith('http'): - kwargs_for_fsspec.update({ + fsspec_kwargs.update({ "headers": { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36' - } + }, + 'get_client': get_client }) # todo: additional security is needed here to prevent users from accessing local paths # however this generally needs to be done with user accounts on all OSes - with fsspec.open(value, mode="rb", **kwargs_for_fsspec) as f: + with fsspec.open(value, mode="rb", **fsspec_kwargs) as f: # from LoadImage img = Image.open(f) for i in ImageSequence.Iterator(img): @@ -677,12 +697,10 @@ class ImageRequestParameter(CustomNode): image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] output_images.append(image) - if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) else: output_image = output_images[0] - return (output_image,) diff --git a/requirements.txt b/requirements.txt index d1a1c1c6e..445b09d32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,4 +49,5 @@ opentelemetry-util-http opentelemetry-instrumentation-aio-pika opentelemetry-instrumentation-requests opentelemetry-semantic-conventions -wrapt>=1.16.0 \ No newline at end of file +wrapt>=1.16.0 +certifi \ No newline at end of file diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index 6f558aa00..423ca98b7 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -1,7 +1,6 @@ import os import pathlib import re -import sys import uuid from datetime import datetime @@ -275,3 +274,12 @@ def test_file_request_parameter(use_temporary_input_directory): assert loaded_image.shape == load_image_node_rgb.shape assert torch.allclose(loaded_image, load_image_node_rgb) + + +def test_file_request_to_http_url_no_exceptions(): + n = ImageRequestParameter() + loaded_image, = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg") + _, height, width, channels = loaded_image.shape + assert width == 484 + assert height == 480 + assert channels == 3