diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py index 21b73907a..db3653e7f 100644 --- a/comfy/ldm/pixeldit/pid.py +++ b/comfy/ldm/pixeldit/pid.py @@ -3,16 +3,24 @@ directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I body + LQ projection branch injected before each MMDiT patch block. """ +import logging from typing import List import torch import torch.nn as nn import torch.nn.functional as F +import comfy.model_management + from .model import PixDiT_T2I from .modules import precompute_freqs_cis_2d +# Warn at most once per process when the ROCm/AOTriton attention path returns +# non-finite values that the guard below sanitizes (see ComfyUI #14249). +_PID_AMD_NONFINITE_WARNED = False + + class SigmaAwareGatePerTokenPerDim(nn.Module): """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq. @@ -217,7 +225,7 @@ class PidNet(PixDiT_T2I): lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws) - return super()._forward( + out = super()._forward( x, timesteps, context=context, attention_mask=attention_mask, transformer_options=transformer_options, @@ -225,3 +233,15 @@ class PidNet(PixDiT_T2I): pid_degrade_sigma=degrade_sigma, **kwargs, ) + if comfy.model_management.is_amd() and out.is_floating_point() and out.dtype in (torch.float16, torch.bfloat16) and not torch.isfinite(out).all(): + global _PID_AMD_NONFINITE_WARNED + if not _PID_AMD_NONFINITE_WARNED: + logging.warning( + "PiD produced non-finite output on AMD; sanitizing NaN/Inf so the " + "decoded image stays usable. This is a known ROCm/AOTriton attention " + "miscompilation on gfx11xx (ComfyUI #14249); for an unaffected run, " + "launch with --use-split-cross-attention." + ) + _PID_AMD_NONFINITE_WARNED = True + out = torch.nan_to_num(out, nan=0.0, posinf=65504.0, neginf=-65504.0) + return out