mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 01:17:24 +08:00
Merge 17e315f6e9 into 04879a8113
This commit is contained in:
commit
e0b4fadf0b
12
main.py
12
main.py
@ -37,6 +37,18 @@ if __name__ == "__main__":
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
# Early port availability check — fail fast before slow custom node loading
|
||||
import socket as _early_socket
|
||||
_port = getattr(args, 'port', 8188)
|
||||
try:
|
||||
with _early_socket.socket(_early_socket.AF_INET, _early_socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
if s.connect_ex(('127.0.0.1', _port)) == 0:
|
||||
print(f"ERROR: Port {_port} is already in use. Please free the port or specify a different one with --port.", file=sys.stderr)
|
||||
raise SystemExit(1)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
76
server.py
76
server.py
@ -144,6 +144,44 @@ def is_loopback(host):
|
||||
return loopback
|
||||
|
||||
|
||||
# CSRF token for DNS rebinding protection — generated once at server start
|
||||
_csrf_token = None
|
||||
|
||||
|
||||
def get_csrf_token():
|
||||
"""Return the per-process CSRF token, generating it lazily on first call."""
|
||||
global _csrf_token
|
||||
if _csrf_token is None:
|
||||
_csrf_token = uuid.uuid4().hex
|
||||
return _csrf_token
|
||||
|
||||
|
||||
def _dns_rebind_check(hostname):
|
||||
"""Resolve *hostname* twice with a 200ms delay.
|
||||
|
||||
Returns True if both resolutions yield the **same** set of loopback addresses.
|
||||
Returns False when the two results differ (classic DNS rebinding signal)
|
||||
or when any result contains a non-loopback IP.
|
||||
"""
|
||||
try:
|
||||
r1 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM)
|
||||
time.sleep(0.2)
|
||||
r2 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
return False # unresolvable → reject
|
||||
|
||||
ips1 = sorted({addr[4][0] for addr in r1})
|
||||
ips2 = sorted({addr[4][0] for addr in r2})
|
||||
|
||||
if ips1 != ips2:
|
||||
return False # addresses changed between lookups → rebind
|
||||
|
||||
for ip in ips1:
|
||||
if not ipaddress.ip_address(ip).is_loopback:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def create_origin_only_middleware():
|
||||
@web.middleware
|
||||
async def origin_only_middleware(request: web.Request, handler):
|
||||
@ -151,9 +189,8 @@ def create_origin_only_middleware():
|
||||
sec_fetch_site = request.headers['Sec-Fetch-Site']
|
||||
if sec_fetch_site == 'cross-site':
|
||||
return web.Response(status=403)
|
||||
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||
#in that case the Host and Origin hostnames won't match
|
||||
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||
|
||||
# --- CSRF + DNS-rebinding guard for loopback requests ---
|
||||
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||
host = request.headers['Host']
|
||||
origin = request.headers['Origin']
|
||||
@ -162,10 +199,9 @@ def create_origin_only_middleware():
|
||||
origin_domain = parsed.netloc.lower()
|
||||
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||
|
||||
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||
loopback = is_loopback(host_domain_parsed.hostname)
|
||||
|
||||
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||
if parsed.port is None:
|
||||
host_domain = host_domain_parsed.hostname
|
||||
if host_domain_parsed.port is None:
|
||||
origin_domain = parsed.hostname
|
||||
@ -175,11 +211,41 @@ def create_origin_only_middleware():
|
||||
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||
return web.Response(status=403)
|
||||
|
||||
# ── DNS rebinding double-resolution check ──
|
||||
# Even when Host == Origin, an attacker-controlled domain can
|
||||
# rebound to 127.0.0.1 after the browser checks the origin.
|
||||
# Resolving twice detects the IP flip.
|
||||
if not _dns_rebind_check(host_domain_parsed.hostname or host_domain):
|
||||
logging.warning("WARNING: possible DNS rebinding detected for %s, returning 403", host_domain)
|
||||
return web.Response(status=403)
|
||||
|
||||
# ── CSRF token validation for state-changing requests ──
|
||||
# State-changing requests from loopback must carry the
|
||||
# server-issued CSRF token (sent as a cookie on every response).
|
||||
if request.method in ("POST", "PUT", "DELETE", "PATCH"):
|
||||
token = get_csrf_token()
|
||||
# Check header first, then fallback to query param
|
||||
client_token = request.headers.get("X-CSRF-Token") or \
|
||||
request.query.get("csrf_token")
|
||||
if client_token != token:
|
||||
logging.warning("WARNING: missing/invalid CSRF token for %s %s (loopback), returning 403",
|
||||
request.method, request.path)
|
||||
return web.Response(status=403, text="Missing or invalid CSRF token",
|
||||
content_type="text/plain")
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
response = web.Response()
|
||||
else:
|
||||
response = await handler(request)
|
||||
|
||||
# Inject CSRF token cookie so same-origin clients can include it
|
||||
token = get_csrf_token()
|
||||
response.set_cookie("comfyui_csrf", token,
|
||||
httponly=False, samesite="Strict",
|
||||
path="/", max_age=86400 * 7)
|
||||
# Also expose it via header for non-cookie environments (e.g. API scripts)
|
||||
response.headers["X-CSRF-Token"] = token
|
||||
|
||||
return response
|
||||
|
||||
return origin_only_middleware
|
||||
|
||||
Loading…
Reference in New Issue
Block a user