mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-12 03:22:37 +08:00
fix: prevent autocast crash with fp8 weights in WAN model addcmul ops
PyTorch 2.8's autocast `prioritize` function doesn't handle FP8 ScalarTypes, causing a sporadic "Unexpected floating ScalarType in at::autocast::prioritize" RuntimeError when fp8_e4m3fn_fast weights are used with the WAN 2.1 model. This commit routes every torch.addcmul call shown in this diff (in WanAttentionBlock, WanAttentionBlockAudio and Head) through a new `_addcmul` helper, which runs the op inside an autocast-disabled context whenever autocast is active, matching the existing pattern used in sub_quadratic_attention.py and other ComfyUI modules. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
359559c913
commit
8e4bc0edac
@ -174,6 +174,15 @@ def repeat_e(e, x):
|
|||||||
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
|
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
|
||||||
|
|
||||||
|
|
||||||
|
def _addcmul(x, y, z):
    """Return ``torch.addcmul(x, y, z)``, computed with autocast turned off.

    Works around the 'Unexpected floating ScalarType in
    at::autocast::prioritize' error reported with fp8 weights.
    """
    # Fast path: autocast is not active, so nothing needs to be disabled.
    if not torch.is_autocast_enabled():
        return torch.addcmul(x, y, z)

    # Autocast is active: run the op in a region where it is disabled for
    # the device type that ``x`` lives on.
    with torch.autocast(device_type=x.device.type, enabled=False):
        return torch.addcmul(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
class WanAttentionBlock(nn.Module):
|
class WanAttentionBlock(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -242,10 +251,10 @@ class WanAttentionBlock(nn.Module):
|
|||||||
# self-attention
|
# self-attention
|
||||||
x = x.contiguous() # otherwise implicit in LayerNorm
|
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||||
y = self.self_attn(
|
y = self.self_attn(
|
||||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
_addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||||
freqs, transformer_options=transformer_options)
|
freqs, transformer_options=transformer_options)
|
||||||
|
|
||||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
x = _addcmul(x, y, repeat_e(e[2], x))
|
||||||
del y
|
del y
|
||||||
|
|
||||||
# cross-attention & ffn
|
# cross-attention & ffn
|
||||||
@ -255,8 +264,8 @@ class WanAttentionBlock(nn.Module):
|
|||||||
for p in patches["attn2_patch"]:
|
for p in patches["attn2_patch"]:
|
||||||
x = p({"x": x, "transformer_options": transformer_options})
|
x = p({"x": x, "transformer_options": transformer_options})
|
||||||
|
|
||||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
x = _addcmul(x, y, repeat_e(e[5], x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -371,7 +380,7 @@ class Head(nn.Module):
|
|||||||
else:
|
else:
|
||||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||||
|
|
||||||
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
|
x = (self.head(_addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -1453,17 +1462,17 @@ class WanAttentionBlockAudio(WanAttentionBlock):
|
|||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
y = self.self_attn(
|
y = self.self_attn(
|
||||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
_addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||||
freqs, transformer_options=transformer_options)
|
freqs, transformer_options=transformer_options)
|
||||||
|
|
||||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
x = _addcmul(x, y, repeat_e(e[2], x))
|
||||||
|
|
||||||
# cross-attention & ffn
|
# cross-attention & ffn
|
||||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
|
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
|
||||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
x = _addcmul(x, y, repeat_e(e[5], x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class DummyAdapterLayer(nn.Module):
|
class DummyAdapterLayer(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user