Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-09-09 10:32:25 +03:00 committed by GitHub
commit 8d618edad7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 1 deletions

View File

@ -5,6 +5,7 @@ import uuid
import glob
import shutil
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
from folder_paths import user_directory
from .app_settings import AppSettings
@ -59,6 +60,10 @@ class UserManager():
return None
if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:

View File

@ -257,9 +257,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
now = time.localtime()
input = input.replace("%year%", str(now.tm_year))
input = input.replace("%month%", str(now.tm_mon).zfill(2))
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))

View File

@ -80,6 +80,33 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware
def create_origin_only_middleware():
@web.middleware
async def origin_only_middleware(request: web.Request, handler):
if 'Host' in request.headers and 'Origin' in request.headers:
host = request.headers['Host']
origin = request.headers['Origin']
host_domain = host.lower()
parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower()
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers
result = urllib.parse.urlsplit('//' + host_domain)
host_domain = result.hostname
if len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
if request.method == "OPTIONS":
response = web.Response()
else:
response = await handler(request)
return response
return origin_only_middleware
class PromptServer():
def __init__(self, loop):
PromptServer.instance = self
@ -99,6 +126,8 @@ class PromptServer():
middlewares = [cache_control]
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
else:
middlewares.append(create_origin_only_middleware())
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)