mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
pixeldit: sanitize and warn on NaN PiD output on AMD fp16/bf16
On AMD ROCm (gfx11xx APUs/GPUs), the PiD PidNet forward pass in fp16/bf16 can return an all-NaN tensor through the ROCm/AOTriton attention path, which then corrupts the decoded image. The default (non-AOTriton) path stays clean, so the bad values come from an AOTriton attention miscompilation on gfx11xx rather than from the PiD math itself (see ROCm/triton#909 and ROCm/aotriton#179).

Guard the PidNet output on AMD: when it is fp16/bf16 and actually contains NaN/Inf, log a one-time warning that points at --use-split-cross-attention and sanitize the values with nan_to_num before decode (NaN becomes 0, +/-Inf is clamped to +/-65504; the same pattern is already used for flux/lumina). Non-AMD devices and fp32 paths are unaffected; on AMD, finite fp16/bf16 outputs only pay a single isfinite() check.

Fixes #14249

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
This commit is contained in:
parent
2cdaaf4a25
commit
70724ddf43
@ -3,16 +3,24 @@ directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
|
|||||||
body + LQ projection branch injected before each MMDiT patch block.
|
body + LQ projection branch injected before each MMDiT patch block.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
from .model import PixDiT_T2I
|
from .model import PixDiT_T2I
|
||||||
from .modules import precompute_freqs_cis_2d
|
from .modules import precompute_freqs_cis_2d
|
||||||
|
|
||||||
|
|
||||||
|
# Warn at most once per process when the ROCm/AOTriton attention path returns
|
||||||
|
# non-finite values that the guard below sanitizes (see ComfyUI #14249).
|
||||||
|
_PID_AMD_NONFINITE_WARNED = False
|
||||||
|
|
||||||
|
|
||||||
class SigmaAwareGatePerTokenPerDim(nn.Module):
|
class SigmaAwareGatePerTokenPerDim(nn.Module):
|
||||||
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
|
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
|
||||||
|
|
||||||
@ -217,7 +225,7 @@ class PidNet(PixDiT_T2I):
|
|||||||
|
|
||||||
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
|
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
|
||||||
|
|
||||||
return super()._forward(
|
out = super()._forward(
|
||||||
x, timesteps,
|
x, timesteps,
|
||||||
context=context, attention_mask=attention_mask,
|
context=context, attention_mask=attention_mask,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
@ -225,3 +233,15 @@ class PidNet(PixDiT_T2I):
|
|||||||
pid_degrade_sigma=degrade_sigma,
|
pid_degrade_sigma=degrade_sigma,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
if comfy.model_management.is_amd() and out.is_floating_point() and out.dtype in (torch.float16, torch.bfloat16) and not torch.isfinite(out).all():
|
||||||
|
global _PID_AMD_NONFINITE_WARNED
|
||||||
|
if not _PID_AMD_NONFINITE_WARNED:
|
||||||
|
logging.warning(
|
||||||
|
"PiD produced non-finite output on AMD; sanitizing NaN/Inf so the "
|
||||||
|
"decoded image stays usable. This is a known ROCm/AOTriton attention "
|
||||||
|
"miscompilation on gfx11xx (ComfyUI #14249); for an unaffected run, "
|
||||||
|
"launch with --use-split-cross-attention."
|
||||||
|
)
|
||||||
|
_PID_AMD_NONFINITE_WARNED = True
|
||||||
|
out = torch.nan_to_num(out, nan=0.0, posinf=65504.0, neginf=-65504.0)
|
||||||
|
return out
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user