mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 11:32:31 +08:00
Fix ace step nan issue on some hardware/pytorch configs. (#12289)
This commit is contained in:
parent
e77b34dfea
commit
26dd7eb421
@ -651,10 +651,10 @@ class Llama2_(nn.Module):
|
|||||||
mask = None
|
mask = None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
||||||
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
|
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
|
||||||
|
|
||||||
if seq_len > 1:
|
if seq_len > 1:
|
||||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
|
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask += causal_mask
|
mask += causal_mask
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user