Resolves SSL errors on Windows when using fsspec

This commit is contained in:
doctorpangloss 2024-05-29 17:34:03 -07:00
parent c56834248c
commit 8a3b49eb49
3 changed files with 37 additions and 10 deletions

View File

@ -6,12 +6,15 @@ import logging
import os
import posixpath
import re
import ssl
import sys
import uuid
from datetime import datetime
from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple
import PIL
import aiohttp
import certifi
import fsspec
import numpy as np
import torch
@ -54,6 +57,17 @@ def is_null_uri(local_uri):
return local_uri == _null_uri or local_uri == "NUL"
async def get_client(**kwargs):
"""
workaround for issues with fsspec on Windows
:param kwargs:
:return:
"""
ssl_context = ssl.create_default_context(cafile=certifi.where())
conn = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(connector=conn, **kwargs)
class FsSpecComfyMetadata(TypedDict, total=True):
prompt_json_str: str
batch_number_str: str
@ -587,8 +601,14 @@ class SaveImagesResponse(CustomNode):
if not uri_is_remote and not os.path.isabs(uri):
uri = os.path.join(output_directory, uri)
abs_path = uri
fsspec_kwargs = {}
if not uri_is_remote:
fsspec_kwargs["auto_mkdir"] = True
# todo: this might need special handling for s3 URLs too
if uri.startswith("http"):
fsspec_kwargs['get_client'] = get_client
try:
with fsspec.open(uri, mode="wb", auto_mkdir=True) as f:
with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
image_as_pil.save(f, format=pil_save_format, **additional_args)
if metadata_uri is not None:
# all values are stringified for the metadata
@ -646,18 +666,18 @@ class ImageRequestParameter(CustomNode):
def execute(self, value: str = "", *args, **kwargs) -> ValidatedNodeResult:
output_images = []
f: OpenFile
kwargs_for_fsspec = {}
fsspec_kwargs = {}
if value.startswith('http'):
kwargs_for_fsspec.update({
fsspec_kwargs.update({
"headers": {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
}
},
'get_client': get_client
})
# todo: additional security is needed here to prevent users from accessing local paths
# however this generally needs to be done with user accounts on all OSes
with fsspec.open(value, mode="rb", **kwargs_for_fsspec) as f:
with fsspec.open(value, mode="rb", **fsspec_kwargs) as f:
# from LoadImage
img = Image.open(f)
for i in ImageSequence.Iterator(img):
@ -677,12 +697,10 @@ class ImageRequestParameter(CustomNode):
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
output_images.append(image)
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
else:
output_image = output_images[0]
return (output_image,)

View File

@ -49,4 +49,5 @@ opentelemetry-util-http
opentelemetry-instrumentation-aio-pika
opentelemetry-instrumentation-requests
opentelemetry-semantic-conventions
wrapt>=1.16.0
wrapt>=1.16.0
certifi

View File

@ -1,7 +1,6 @@
import os
import pathlib
import re
import sys
import uuid
from datetime import datetime
@ -275,3 +274,12 @@ def test_file_request_parameter(use_temporary_input_directory):
assert loaded_image.shape == load_image_node_rgb.shape
assert torch.allclose(loaded_image, load_image_node_rgb)
def test_file_request_to_http_url_no_exceptions():
n = ImageRequestParameter()
loaded_image, = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg")
_, height, width, channels = loaded_image.shape
assert width == 484
assert height == 480
assert channels == 3