From a694c3b58f4cbdbd8331d92fb868767e34ba37a8 Mon Sep 17 00:00:00 2001 From: Hermes Admin Date: Sun, 24 May 2026 12:12:40 +0800 Subject: [PATCH] fix(security): harden origin middleware against DNS rebinding attacks (#12860) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add CSRF token generation (per-process, sent as cookie + response header) - Require CSRF token on state-changing requests (POST/PUT/DELETE/PATCH) from loopback origins — prevents cross-site request forgery via DNS rebinding - Add double-DNS-resolution check (_dns_rebind_check): resolves the Host hostname twice with 200ms delay; rejects if addresses differ (rebind signal) or if any resolved IP is non-loopback - Preserve all existing Sec-Fetch-Site and Host/Origin mismatch checks CVE reference: huntr.com/bounties/f1458e43-64a7-4df2-b71c-9ca453755dc7 --- server.py | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 71 insertions(+), 5 deletions(-) diff --git a/server.py b/server.py index 44470b904..d956a0b6b 100644 --- a/server.py +++ b/server.py @@ -144,6 +144,44 @@ def is_loopback(host): return loopback +# CSRF token for DNS rebinding protection — generated once at server start +_csrf_token = None + + +def get_csrf_token(): + """Return the per-process CSRF token, generating it lazily on first call.""" + global _csrf_token + if _csrf_token is None: + _csrf_token = uuid.uuid4().hex + return _csrf_token + + +def _dns_rebind_check(hostname): + """Resolve *hostname* twice with a 200ms delay. + + Returns True if both resolutions yield the **same** set of loopback addresses. + Returns False when the two results differ (classic DNS rebinding signal) + or when any result contains a non-loopback IP. + """ + try: + r1 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM) + time.sleep(0.2) + r2 = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM) + except socket.gaierror: + return False # unresolvable → reject + + ips1 = sorted({addr[4][0] for addr in r1}) + ips2 = sorted({addr[4][0] for addr in r2}) + + if ips1 != ips2: + return False # addresses changed between lookups → rebind + + for ip in ips1: + if not ipaddress.ip_address(ip).is_loopback: + return False + return True + + def create_origin_only_middleware(): @web.middleware async def origin_only_middleware(request: web.Request, handler): @@ -151,9 +189,8 @@ def create_origin_only_middleware(): sec_fetch_site = request.headers['Sec-Fetch-Site'] if sec_fetch_site == 'cross-site': return web.Response(status=403) - #this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason. - #in that case the Host and Origin hostnames won't match - #I know the proper fix would be to add a cookie but this should take care of the problem in the meantime + + # --- CSRF + DNS-rebinding guard for loopback requests --- if 'Host' in request.headers and 'Origin' in request.headers: host = request.headers['Host'] origin = request.headers['Origin'] @@ -162,10 +199,9 @@ def create_origin_only_middleware(): origin_domain = parsed.netloc.lower() host_domain_parsed = urllib.parse.urlsplit('//' + host_domain) - #limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit loopback = is_loopback(host_domain_parsed.hostname) - if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host + if parsed.port is None: host_domain = host_domain_parsed.hostname if host_domain_parsed.port is None: origin_domain = parsed.hostname @@ -175,11 +211,41 @@ def create_origin_only_middleware(): logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain)) return web.Response(status=403) + # ── DNS rebinding double-resolution check ── + # Even when Host == Origin, an attacker-controlled domain can + # rebound to 127.0.0.1 after the browser checks the origin. + # Resolving twice detects the IP flip. + if not _dns_rebind_check(host_domain_parsed.hostname or host_domain): + logging.warning("WARNING: possible DNS rebinding detected for %s, returning 403", host_domain) + return web.Response(status=403) + + # ── CSRF token validation for state-changing requests ── + # State-changing requests from loopback must carry the + # server-issued CSRF token (sent as a cookie on every response). + if request.method in ("POST", "PUT", "DELETE", "PATCH"): + token = get_csrf_token() + # Check header first, then fallback to query param + client_token = request.headers.get("X-CSRF-Token") or \ + request.query.get("csrf_token") + if client_token != token: + logging.warning("WARNING: missing/invalid CSRF token for %s %s (loopback), returning 403", + request.method, request.path) + return web.Response(status=403, text="Missing or invalid CSRF token", + content_type="text/plain") + if request.method == "OPTIONS": response = web.Response() else: response = await handler(request) + # Inject CSRF token cookie so same-origin clients can include it + token = get_csrf_token() + response.set_cookie("comfyui_csrf", token, + httponly=False, samesite="Strict", + path="/", max_age=86400 * 7) + # Also expose it via header for non-cookie environments (e.g. API scripts) + response.headers["X-CSRF-Token"] = token + return response return origin_only_middleware