From 7db81e43df90358ed057fce829d8448e4707f273 Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Thu, 9 Apr 2026 20:59:31 -0700 Subject: [PATCH] fix: prevent autocast crash in WAN model addcmul ops Wraps torch.addcmul calls in WAN attention blocks with autocast-disabled context to prevent 'Unexpected floating ScalarType in at::autocast::prioritize' RuntimeError. This occurs when upstream nodes (e.g. SAM3) leave CUDA autocast enabled - PyTorch 2.8's autocast promote dispatch for addcmul hits an unhandled dtype in the prioritize function. Uses torch.is_autocast_enabled(device_type) (non-deprecated API) and only applies the workaround when autocast is actually active (zero overhead otherwise). --- comfy/ldm/wan/model.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index b2287dba9..7d2933f94 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -174,6 +174,17 @@ def repeat_e(e, x): return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)] +def _addcmul(x, y, z): + """torch.addcmul wrapper that disables autocast to prevent + 'Unexpected floating ScalarType in at::autocast::prioritize' when + upstream nodes (e.g. SAM3) leave CUDA autocast enabled.""" + device_type = x.device.type + if torch.is_autocast_enabled(device_type): + with torch.autocast(device_type=device_type, enabled=False): + return torch.addcmul(x, y, z) + return torch.addcmul(x, y, z) + + class WanAttentionBlock(nn.Module): def __init__(self, @@ -242,10 +253,10 @@ class WanAttentionBlock(nn.Module): # self-attention x = x.contiguous() # otherwise implicit in LayerNorm y = self.self_attn( - torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + _addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), freqs, transformer_options=transformer_options) - x = torch.addcmul(x, y, repeat_e(e[2], x)) + x = _addcmul(x, y, repeat_e(e[2], x)) del y # cross-attention & ffn @@ -255,8 +266,8 @@ class WanAttentionBlock(nn.Module): for p in patches["attn2_patch"]: x = p({"x": x, "transformer_options": transformer_options}) - y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) - x = torch.addcmul(x, y, repeat_e(e[5], x)) + y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = _addcmul(x, y, repeat_e(e[5], x)) return x @@ -371,7 +382,7 @@ class Head(nn.Module): else: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) - x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x)))) + x = (self.head(_addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x)))) return x @@ -1453,17 +1464,17 @@ class WanAttentionBlockAudio(WanAttentionBlock): # self-attention y = self.self_attn( - torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + _addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), freqs, transformer_options=transformer_options) - x = torch.addcmul(x, y, repeat_e(e[2], x)) + x = _addcmul(x, y, repeat_e(e[2], x)) # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) if audio is not None: x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options) - y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) - x = torch.addcmul(x, y, repeat_e(e[5], x)) + y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = _addcmul(x, y, repeat_e(e[5], x)) return x class DummyAdapterLayer(nn.Module):