pixeldit: sanitize and warn on NaN PiD output on AMD fp16/bf16

On AMD ROCm (gfx11xx APUs/GPUs), the PiD PidNet forward in fp16/bf16 can
return an all-NaN tensor through the ROCm/AOTriton attention path, which
then corrupts the decoded image. The default (non-AOTriton) path stays
clean, so the bad values come from an AOTriton attention miscompilation on
gfx11xx rather than the PiD math itself (see ROCm/triton#909 and
ROCm/aotriton#179).

Guard the PidNet output on AMD: when it is fp16/bf16 and actually contains
NaN/Inf, log a one-time warning that points at --use-split-cross-attention
and clamp the values with nan_to_num before decode (the same pattern
already used for flux/lumina). Non-AMD devices and fp32 paths are
unaffected; finite outputs only pay a single isfinite() check.

Fixes #14249

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
This commit is contained in:
liminfei-amd 2026-06-07 12:19:19 +08:00
parent 2cdaaf4a25
commit 70724ddf43

View File

@ -3,16 +3,24 @@ directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
body + LQ projection branch injected before each MMDiT patch block.
"""
import logging
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
from .model import PixDiT_T2I
from .modules import precompute_freqs_cis_2d
# Warn at most once per process when the ROCm/AOTriton attention path returns
# non-finite values that the guard below sanitizes (see ComfyUI #14249).
_PID_AMD_NONFINITE_WARNED = False
class SigmaAwareGatePerTokenPerDim(nn.Module):
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
@ -217,7 +225,7 @@ class PidNet(PixDiT_T2I):
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
return super()._forward(
out = super()._forward(
x, timesteps,
context=context, attention_mask=attention_mask,
transformer_options=transformer_options,
@ -225,3 +233,15 @@ class PidNet(PixDiT_T2I):
pid_degrade_sigma=degrade_sigma,
**kwargs,
)
if comfy.model_management.is_amd() and out.is_floating_point() and out.dtype in (torch.float16, torch.bfloat16) and not torch.isfinite(out).all():
global _PID_AMD_NONFINITE_WARNED
if not _PID_AMD_NONFINITE_WARNED:
logging.warning(
"PiD produced non-finite output on AMD; sanitizing NaN/Inf so the "
"decoded image stays usable. This is a known ROCm/AOTriton attention "
"miscompilation on gfx11xx (ComfyUI #14249); for an unaffected run, "
"launch with --use-split-cross-attention."
)
_PID_AMD_NONFINITE_WARNED = True
out = torch.nan_to_num(out, nan=0.0, posinf=65504.0, neginf=-65504.0)
return out