pixeldit: sanitize and warn on NaN PiD output on AMD fp16/bf16

On AMD ROCm (gfx11xx APUs/GPUs), the PiD PidNet forward in fp16/bf16 can
return an all-NaN tensor through the ROCm/AOTriton attention path, which
then corrupts the decoded image. The default (non-AOTriton) path stays
clean, so the bad values come from an AOTriton attention miscompilation on
gfx11xx rather than the PiD math itself (see ROCm/triton#909 and
ROCm/aotriton#179).

Guard the PidNet output on AMD: when it is fp16/bf16 and actually contains
NaN/Inf, log a one-time warning that points at --use-split-cross-attention
and clamp the values with nan_to_num before decode (the same pattern
already used for flux/lumina). Non-AMD devices and fp32 paths are
unaffected; finite outputs only pay a single isfinite() check.

Fixes #14249

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
This commit is contained in:
liminfei-amd 2026-06-07 12:19:19 +08:00
parent 2cdaaf4a25
commit 70724ddf43

View File

@ -3,16 +3,24 @@ directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
body + LQ projection branch injected before each MMDiT patch block. body + LQ projection branch injected before each MMDiT patch block.
""" """
import logging
from typing import List from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import comfy.model_management
from .model import PixDiT_T2I from .model import PixDiT_T2I
from .modules import precompute_freqs_cis_2d from .modules import precompute_freqs_cis_2d
# Warn at most once per process when the ROCm/AOTriton attention path returns
# non-finite values that the guard below sanitizes (see ComfyUI #14249).
_PID_AMD_NONFINITE_WARNED = False
class SigmaAwareGatePerTokenPerDim(nn.Module): class SigmaAwareGatePerTokenPerDim(nn.Module):
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq. """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
@ -217,7 +225,7 @@ class PidNet(PixDiT_T2I):
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws) lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
return super()._forward( out = super()._forward(
x, timesteps, x, timesteps,
context=context, attention_mask=attention_mask, context=context, attention_mask=attention_mask,
transformer_options=transformer_options, transformer_options=transformer_options,
@ -225,3 +233,15 @@ class PidNet(PixDiT_T2I):
pid_degrade_sigma=degrade_sigma, pid_degrade_sigma=degrade_sigma,
**kwargs, **kwargs,
) )
if comfy.model_management.is_amd() and out.is_floating_point() and out.dtype in (torch.float16, torch.bfloat16) and not torch.isfinite(out).all():
global _PID_AMD_NONFINITE_WARNED
if not _PID_AMD_NONFINITE_WARNED:
logging.warning(
"PiD produced non-finite output on AMD; sanitizing NaN/Inf so the "
"decoded image stays usable. This is a known ROCm/AOTriton attention "
"miscompilation on gfx11xx (ComfyUI #14249); for an unaffected run, "
"launch with --use-split-cross-attention."
)
_PID_AMD_NONFINITE_WARNED = True
out = torch.nan_to_num(out, nan=0.0, posinf=65504.0, neginf=-65504.0)
return out