diff --git a/app/user_manager.py b/app/user_manager.py index 53dff18b7..95ad157c3 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -5,6 +5,7 @@ import uuid import glob import shutil from aiohttp import web +from urllib import parse from comfy.cli_args import args from folder_paths import user_directory from .app_settings import AppSettings @@ -59,6 +60,10 @@ class UserManager(): return None if file is not None: + # Check if filename is url encoded + if "%" in file: + file = parse.unquote(file) + # prevent leaving /{type}/{user} path = os.path.abspath(os.path.join(user_root, file)) if os.path.commonpath((user_root, path)) != user_root: diff --git a/folder_paths.py b/folder_paths.py index 74a7d527c..b154448fc 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -257,9 +257,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im def compute_vars(input: str, image_width: int, image_height: int) -> str: input = input.replace("%width%", str(image_width)) input = input.replace("%height%", str(image_height)) + now = time.localtime() + input = input.replace("%year%", str(now.tm_year)) + input = input.replace("%month%", str(now.tm_mon).zfill(2)) + input = input.replace("%day%", str(now.tm_mday).zfill(2)) + input = input.replace("%hour%", str(now.tm_hour).zfill(2)) + input = input.replace("%minute%", str(now.tm_min).zfill(2)) + input = input.replace("%second%", str(now.tm_sec).zfill(2)) return input - filename_prefix = compute_vars(filename_prefix, image_width, image_height) + if "%" in filename_prefix: + filename_prefix = compute_vars(filename_prefix, image_width, image_height) subfolder = os.path.dirname(os.path.normpath(filename_prefix)) filename = os.path.basename(os.path.normpath(filename_prefix)) diff --git a/server.py b/server.py index 3611880da..93147f6c5 100644 --- a/server.py +++ b/server.py @@ -80,6 +80,33 @@ def create_cors_middleware(allowed_origin: str): return cors_middleware +def create_origin_only_middleware(): + @web.middleware + async def origin_only_middleware(request: web.Request, handler): + if 'Host' in request.headers and 'Origin' in request.headers: + host = request.headers['Host'] + origin = request.headers['Origin'] + host_domain = host.lower() + parsed = urllib.parse.urlparse(origin) + origin_domain = parsed.netloc.lower() + if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers + result = urllib.parse.urlsplit('//' + host_domain) + host_domain = result.hostname + + if len(host_domain) > 0 and len(origin_domain) > 0: + if host_domain != origin_domain: + logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain)) + return web.Response(status=403) + + if request.method == "OPTIONS": + response = web.Response() + else: + response = await handler(request) + + return response + + return origin_only_middleware + class PromptServer(): def __init__(self, loop): PromptServer.instance = self @@ -99,6 +126,8 @@ class PromptServer(): middlewares = [cache_control] if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) + else: + middlewares.append(create_origin_only_middleware()) max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)