mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
8d618edad7
@ -5,6 +5,7 @@ import uuid
|
||||
import glob
|
||||
import shutil
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
from folder_paths import user_directory
|
||||
from .app_settings import AppSettings
|
||||
@ -59,6 +60,10 @@ class UserManager():
|
||||
return None
|
||||
|
||||
if file is not None:
|
||||
# Check if filename is url encoded
|
||||
if "%" in file:
|
||||
file = parse.unquote(file)
|
||||
|
||||
# prevent leaving /{type}/{user}
|
||||
path = os.path.abspath(os.path.join(user_root, file))
|
||||
if os.path.commonpath((user_root, path)) != user_root:
|
||||
|
||||
@ -257,9 +257,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||
input = input.replace("%width%", str(image_width))
|
||||
input = input.replace("%height%", str(image_height))
|
||||
now = time.localtime()
|
||||
input = input.replace("%year%", str(now.tm_year))
|
||||
input = input.replace("%month%", str(now.tm_mon).zfill(2))
|
||||
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
if "%" in filename_prefix:
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
29
server.py
29
server.py
@ -80,6 +80,33 @@ def create_cors_middleware(allowed_origin: str):
|
||||
|
||||
return cors_middleware
|
||||
|
||||
def create_origin_only_middleware():
|
||||
@web.middleware
|
||||
async def origin_only_middleware(request: web.Request, handler):
|
||||
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||
host = request.headers['Host']
|
||||
origin = request.headers['Origin']
|
||||
host_domain = host.lower()
|
||||
parsed = urllib.parse.urlparse(origin)
|
||||
origin_domain = parsed.netloc.lower()
|
||||
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers
|
||||
result = urllib.parse.urlsplit('//' + host_domain)
|
||||
host_domain = result.hostname
|
||||
|
||||
if len(host_domain) > 0 and len(origin_domain) > 0:
|
||||
if host_domain != origin_domain:
|
||||
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||
return web.Response(status=403)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
response = web.Response()
|
||||
else:
|
||||
response = await handler(request)
|
||||
|
||||
return response
|
||||
|
||||
return origin_only_middleware
|
||||
|
||||
class PromptServer():
|
||||
def __init__(self, loop):
|
||||
PromptServer.instance = self
|
||||
@ -99,6 +126,8 @@ class PromptServer():
|
||||
middlewares = [cache_control]
|
||||
if args.enable_cors_header:
|
||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||
else:
|
||||
middlewares.append(create_origin_only_middleware())
|
||||
|
||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user