fix: prevent autocast crash in WAN model addcmul ops

Wraps torch.addcmul calls in WAN attention blocks in an autocast-disabled
context to prevent the 'Unexpected floating ScalarType in at::autocast::prioritize'
RuntimeError. This occurs when upstream nodes (e.g. SAM3) leave CUDA autocast
enabled: PyTorch 2.8's autocast promote dispatch for addcmul hits an unhandled
dtype in the prioritize function.

Uses torch.is_autocast_enabled(device_type) (non-deprecated API) and only
applies the workaround when autocast is actually active (zero overhead otherwise).
This commit is contained in:
Deep Mehta 2026-04-09 20:59:31 -07:00
parent e6be419a30
commit 7db81e43df

View File

@ -174,6 +174,17 @@ def repeat_e(e, x):
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
def _addcmul(x, y, z):
"""torch.addcmul wrapper that disables autocast to prevent
'Unexpected floating ScalarType in at::autocast::prioritize' when
upstream nodes (e.g. SAM3) leave CUDA autocast enabled."""
device_type = x.device.type
if torch.is_autocast_enabled(device_type):
with torch.autocast(device_type=device_type, enabled=False):
return torch.addcmul(x, y, z)
return torch.addcmul(x, y, z)
class WanAttentionBlock(nn.Module):
def __init__(self,
@ -242,10 +253,10 @@ class WanAttentionBlock(nn.Module):
# self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
_addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
x = _addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
@ -255,8 +266,8 @@ class WanAttentionBlock(nn.Module):
for p in patches["attn2_patch"]:
x = p({"x": x, "transformer_options": transformer_options})
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = _addcmul(x, y, repeat_e(e[5], x))
return x
@ -371,7 +382,7 @@ class Head(nn.Module):
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
x = (self.head(_addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
return x
@ -1453,17 +1464,17 @@ class WanAttentionBlockAudio(WanAttentionBlock):
# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
_addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
x = _addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if audio is not None:
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = _addcmul(x, y, repeat_e(e[5], x))
return x
class DummyAdapterLayer(nn.Module):