Resolves SSL errors on Windows when using fsspec

This commit is contained in:
doctorpangloss 2024-05-29 17:34:03 -07:00
parent c56834248c
commit 8a3b49eb49
3 changed files with 37 additions and 10 deletions

View File

@ -6,12 +6,15 @@ import logging
import os import os
import posixpath import posixpath
import re import re
import ssl
import sys import sys
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple
import PIL import PIL
import aiohttp
import certifi
import fsspec import fsspec
import numpy as np import numpy as np
import torch import torch
@ -54,6 +57,17 @@ def is_null_uri(local_uri):
return local_uri == _null_uri or local_uri == "NUL" return local_uri == _null_uri or local_uri == "NUL"
async def get_client(**kwargs):
"""
workaround for issues with fsspec on Windows
:param kwargs:
:return:
"""
ssl_context = ssl.create_default_context(cafile=certifi.where())
conn = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(connector=conn, **kwargs)
class FsSpecComfyMetadata(TypedDict, total=True): class FsSpecComfyMetadata(TypedDict, total=True):
prompt_json_str: str prompt_json_str: str
batch_number_str: str batch_number_str: str
@ -587,8 +601,14 @@ class SaveImagesResponse(CustomNode):
if not uri_is_remote and not os.path.isabs(uri): if not uri_is_remote and not os.path.isabs(uri):
uri = os.path.join(output_directory, uri) uri = os.path.join(output_directory, uri)
abs_path = uri abs_path = uri
fsspec_kwargs = {}
if not uri_is_remote:
fsspec_kwargs["auto_mkdir"] = True
# todo: this might need special handling for s3 URLs too
if uri.startswith("http"):
fsspec_kwargs['get_client'] = get_client
try: try:
with fsspec.open(uri, mode="wb", auto_mkdir=True) as f: with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
image_as_pil.save(f, format=pil_save_format, **additional_args) image_as_pil.save(f, format=pil_save_format, **additional_args)
if metadata_uri is not None: if metadata_uri is not None:
# all values are stringified for the metadata # all values are stringified for the metadata
@ -646,18 +666,18 @@ class ImageRequestParameter(CustomNode):
def execute(self, value: str = "", *args, **kwargs) -> ValidatedNodeResult: def execute(self, value: str = "", *args, **kwargs) -> ValidatedNodeResult:
output_images = [] output_images = []
f: OpenFile f: OpenFile
kwargs_for_fsspec = {} fsspec_kwargs = {}
if value.startswith('http'): if value.startswith('http'):
kwargs_for_fsspec.update({ fsspec_kwargs.update({
"headers": { "headers": {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36' 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
} },
'get_client': get_client
}) })
# todo: additional security is needed here to prevent users from accessing local paths # todo: additional security is needed here to prevent users from accessing local paths
# however this generally needs to be done with user accounts on all OSes # however this generally needs to be done with user accounts on all OSes
with fsspec.open(value, mode="rb", **kwargs_for_fsspec) as f: with fsspec.open(value, mode="rb", **fsspec_kwargs) as f:
# from LoadImage # from LoadImage
img = Image.open(f) img = Image.open(f)
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
@ -677,12 +697,10 @@ class ImageRequestParameter(CustomNode):
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
output_images.append(image) output_images.append(image)
if len(output_images) > 1: if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0) output_image = torch.cat(output_images, dim=0)
else: else:
output_image = output_images[0] output_image = output_images[0]
return (output_image,) return (output_image,)

View File

@ -49,4 +49,5 @@ opentelemetry-util-http
opentelemetry-instrumentation-aio-pika opentelemetry-instrumentation-aio-pika
opentelemetry-instrumentation-requests opentelemetry-instrumentation-requests
opentelemetry-semantic-conventions opentelemetry-semantic-conventions
wrapt>=1.16.0 wrapt>=1.16.0
certifi

View File

@ -1,7 +1,6 @@
import os import os
import pathlib import pathlib
import re import re
import sys
import uuid import uuid
from datetime import datetime from datetime import datetime
@ -275,3 +274,12 @@ def test_file_request_parameter(use_temporary_input_directory):
assert loaded_image.shape == load_image_node_rgb.shape assert loaded_image.shape == load_image_node_rgb.shape
assert torch.allclose(loaded_image, load_image_node_rgb) assert torch.allclose(loaded_image, load_image_node_rgb)
def test_file_request_to_http_url_no_exceptions():
n = ImageRequestParameter()
loaded_image, = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg")
_, height, width, channels = loaded_image.shape
assert width == 484
assert height == 480
assert channels == 3