mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Merge 17e315f6e9 into 04879a8113
This commit is contained in:
commit
e0b4fadf0b
12
main.py
12
main.py
@ -37,6 +37,18 @@ if __name__ == "__main__":
|
|||||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||||
os.environ['DO_NOT_TRACK'] = '1'
|
os.environ['DO_NOT_TRACK'] = '1'
|
||||||
|
|
||||||
|
# Early port availability check — fail fast before slow custom node loading
|
||||||
|
import socket as _early_socket
|
||||||
|
_port = getattr(args, 'port', 8188)
|
||||||
|
try:
|
||||||
|
with _early_socket.socket(_early_socket.AF_INET, _early_socket.SOCK_STREAM) as s:
|
||||||
|
s.settimeout(1)
|
||||||
|
if s.connect_ex(('127.0.0.1', _port)) == 0:
|
||||||
|
print(f"ERROR: Port {_port} is already in use. Please free the port or specify a different one with --port.", file=sys.stderr)
|
||||||
|
raise SystemExit(1)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||||
|
|
||||||
import comfy_aimdo.control
|
import comfy_aimdo.control
|
||||||
|
|||||||
76
server.py
76
server.py
@ -144,6 +144,44 @@ def is_loopback(host):
|
|||||||
return loopback
|
return loopback
|
||||||
|
|
||||||
|
|
||||||
|
# CSRF token for DNS rebinding protection — generated once at server start
|
||||||
|
_csrf_token = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_csrf_token():
|
||||||
|
"""Return the per-process CSRF token, generating it lazily on first call."""
|
||||||
|
global _csrf_token
|
||||||
|
if _csrf_token is None:
|
||||||
|
_csrf_token = uuid.uuid4().hex
|
||||||
|
return _csrf_token
|
||||||
|
|
||||||
|
|
||||||
|
def _dns_rebind_check(hostname):
|
||||||
|
"""Resolve *hostname* twice with a 200ms delay.
|
||||||
|
|
||||||
|
Returns True if both resolutions yield the **same** set of loopback addresses.
|
||||||
|
Returns False when the two results differ (classic DNS rebinding signal)
|
||||||
|
or when any result contains a non-loopback IP.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
r1 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
time.sleep(0.2)
|
||||||
|
r2 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
except socket.gaierror:
|
||||||
|
return False # unresolvable → reject
|
||||||
|
|
||||||
|
ips1 = sorted({addr[4][0] for addr in r1})
|
||||||
|
ips2 = sorted({addr[4][0] for addr in r2})
|
||||||
|
|
||||||
|
if ips1 != ips2:
|
||||||
|
return False # addresses changed between lookups → rebind
|
||||||
|
|
||||||
|
for ip in ips1:
|
||||||
|
if not ipaddress.ip_address(ip).is_loopback:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def create_origin_only_middleware():
|
def create_origin_only_middleware():
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def origin_only_middleware(request: web.Request, handler):
|
async def origin_only_middleware(request: web.Request, handler):
|
||||||
@ -151,9 +189,8 @@ def create_origin_only_middleware():
|
|||||||
sec_fetch_site = request.headers['Sec-Fetch-Site']
|
sec_fetch_site = request.headers['Sec-Fetch-Site']
|
||||||
if sec_fetch_site == 'cross-site':
|
if sec_fetch_site == 'cross-site':
|
||||||
return web.Response(status=403)
|
return web.Response(status=403)
|
||||||
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
|
||||||
#in that case the Host and Origin hostnames won't match
|
# --- CSRF + DNS-rebinding guard for loopback requests ---
|
||||||
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
|
||||||
if 'Host' in request.headers and 'Origin' in request.headers:
|
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||||
host = request.headers['Host']
|
host = request.headers['Host']
|
||||||
origin = request.headers['Origin']
|
origin = request.headers['Origin']
|
||||||
@ -162,10 +199,9 @@ def create_origin_only_middleware():
|
|||||||
origin_domain = parsed.netloc.lower()
|
origin_domain = parsed.netloc.lower()
|
||||||
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||||
|
|
||||||
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
|
||||||
loopback = is_loopback(host_domain_parsed.hostname)
|
loopback = is_loopback(host_domain_parsed.hostname)
|
||||||
|
|
||||||
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
if parsed.port is None:
|
||||||
host_domain = host_domain_parsed.hostname
|
host_domain = host_domain_parsed.hostname
|
||||||
if host_domain_parsed.port is None:
|
if host_domain_parsed.port is None:
|
||||||
origin_domain = parsed.hostname
|
origin_domain = parsed.hostname
|
||||||
@ -175,11 +211,41 @@ def create_origin_only_middleware():
|
|||||||
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||||
return web.Response(status=403)
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
# ── DNS rebinding double-resolution check ──
|
||||||
|
# Even when Host == Origin, an attacker-controlled domain can
|
||||||
|
# rebound to 127.0.0.1 after the browser checks the origin.
|
||||||
|
# Resolving twice detects the IP flip.
|
||||||
|
if not _dns_rebind_check(host_domain_parsed.hostname or host_domain):
|
||||||
|
logging.warning("WARNING: possible DNS rebinding detected for %s, returning 403", host_domain)
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
# ── CSRF token validation for state-changing requests ──
|
||||||
|
# State-changing requests from loopback must carry the
|
||||||
|
# server-issued CSRF token (sent as a cookie on every response).
|
||||||
|
if request.method in ("POST", "PUT", "DELETE", "PATCH"):
|
||||||
|
token = get_csrf_token()
|
||||||
|
# Check header first, then fallback to query param
|
||||||
|
client_token = request.headers.get("X-CSRF-Token") or \
|
||||||
|
request.query.get("csrf_token")
|
||||||
|
if client_token != token:
|
||||||
|
logging.warning("WARNING: missing/invalid CSRF token for %s %s (loopback), returning 403",
|
||||||
|
request.method, request.path)
|
||||||
|
return web.Response(status=403, text="Missing or invalid CSRF token",
|
||||||
|
content_type="text/plain")
|
||||||
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
response = web.Response()
|
response = web.Response()
|
||||||
else:
|
else:
|
||||||
response = await handler(request)
|
response = await handler(request)
|
||||||
|
|
||||||
|
# Inject CSRF token cookie so same-origin clients can include it
|
||||||
|
token = get_csrf_token()
|
||||||
|
response.set_cookie("comfyui_csrf", token,
|
||||||
|
httponly=False, samesite="Strict",
|
||||||
|
path="/", max_age=86400 * 7)
|
||||||
|
# Also expose it via header for non-cookie environments (e.g. API scripts)
|
||||||
|
response.headers["X-CSRF-Token"] = token
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
return origin_only_middleware
|
return origin_only_middleware
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user