fix(security): harden origin middleware against DNS rebinding attacks (#12860)

- Add CSRF token generation (per-process, sent as cookie + response header)
- Require CSRF token on state-changing requests (POST/PUT/DELETE/PATCH)
  from loopback origins — prevents cross-site request forgery via DNS rebinding
- Add double-DNS-resolution check (_dns_rebind_check): resolves the Host
  hostname twice with 200ms delay; rejects if addresses differ (rebind signal)
  or if any resolved IP is non-loopback
- Preserve all existing Sec-Fetch-Site and Host/Origin mismatch checks

CVE reference: huntr.com/bounties/f1458e43-64a7-4df2-b71c-9ca453755dc7
This commit is contained in:
Hermes Admin 2026-05-24 12:12:40 +08:00
parent ea62dc11c9
commit a694c3b58f

View File

@ -144,6 +144,44 @@ def is_loopback(host):
return loopback
# CSRF token for DNS rebinding protection — generated once at server start
_csrf_token = None
def get_csrf_token():
"""Return the per-process CSRF token, generating it lazily on first call."""
global _csrf_token
if _csrf_token is None:
_csrf_token = uuid.uuid4().hex
return _csrf_token
def _dns_rebind_check(hostname):
"""Resolve *hostname* twice with a 200ms delay.
Returns True if both resolutions yield the **same** set of loopback addresses.
Returns False when the two results differ (classic DNS rebinding signal)
or when any result contains a non-loopback IP.
"""
try:
r1 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM)
time.sleep(0.2)
r2 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM)
except socket.gaierror:
return False # unresolvable → reject
ips1 = sorted({addr[4][0] for addr in r1})
ips2 = sorted({addr[4][0] for addr in r2})
if ips1 != ips2:
return False # addresses changed between lookups → rebind
for ip in ips1:
if not ipaddress.ip_address(ip).is_loopback:
return False
return True
def create_origin_only_middleware():
@web.middleware
async def origin_only_middleware(request: web.Request, handler):
@ -151,9 +189,8 @@ def create_origin_only_middleware():
sec_fetch_site = request.headers['Sec-Fetch-Site']
if sec_fetch_site == 'cross-site':
return web.Response(status=403)
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
#in that case the Host and Origin hostnames won't match
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
# --- CSRF + DNS-rebinding guard for loopback requests ---
if 'Host' in request.headers and 'Origin' in request.headers:
host = request.headers['Host']
origin = request.headers['Origin']
@ -162,10 +199,9 @@ def create_origin_only_middleware():
origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
if parsed.port is None:
host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None:
origin_domain = parsed.hostname
@ -175,11 +211,41 @@ def create_origin_only_middleware():
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
# ── DNS rebinding double-resolution check ──
# Even when Host == Origin, an attacker-controlled domain can
# rebound to 127.0.0.1 after the browser checks the origin.
# Resolving twice detects the IP flip.
if not _dns_rebind_check(host_domain_parsed.hostname or host_domain):
logging.warning("WARNING: possible DNS rebinding detected for %s, returning 403", host_domain)
return web.Response(status=403)
# ── CSRF token validation for state-changing requests ──
# State-changing requests from loopback must carry the
# server-issued CSRF token (sent as a cookie on every response).
if request.method in ("POST", "PUT", "DELETE", "PATCH"):
token = get_csrf_token()
# Check header first, then fallback to query param
client_token = request.headers.get("X-CSRF-Token") or \
request.query.get("csrf_token")
if client_token != token:
logging.warning("WARNING: missing/invalid CSRF token for %s %s (loopback), returning 403",
request.method, request.path)
return web.Response(status=403, text="Missing or invalid CSRF token",
content_type="text/plain")
if request.method == "OPTIONS":
response = web.Response()
else:
response = await handler(request)
# Inject CSRF token cookie so same-origin clients can include it
token = get_csrf_token()
response.set_cookie("comfyui_csrf", token,
httponly=False, samesite="Strict",
path="/", max_age=86400 * 7)
# Also expose it via header for non-cookie environments (e.g. API scripts)
response.headers["X-CSRF-Token"] = token
return response
return origin_only_middleware