mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
Some optimizations to make Ernie inference a bit faster. (#13472)
This commit is contained in:
parent
b9dedea57d
commit
3d816db07f
@ -118,8 +118,6 @@ class ErnieImageAttention(nn.Module):
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
query, key = query.to(x.dtype), key.to(x.dtype)
|
||||
|
||||
q_flat = query.reshape(B, S, -1)
|
||||
k_flat = key.reshape(B, S, -1)
|
||||
|
||||
@ -161,16 +159,16 @@ class ErnieImageSharedAdaLNBlock(nn.Module):
|
||||
|
||||
residual = x
|
||||
x_norm = self.adaLN_sa_ln(x)
|
||||
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
|
||||
x_norm = x_norm * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
|
||||
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
|
||||
x = residual + gate_msa * attn_out
|
||||
|
||||
residual = x
|
||||
x_norm = self.adaLN_mlp_ln(x)
|
||||
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
|
||||
x_norm = x_norm * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
|
||||
return residual + gate_mlp * self.mlp(x_norm)
|
||||
|
||||
class ErnieImageAdaLNContinuous(nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||
@ -183,7 +181,7 @@ class ErnieImageAdaLNContinuous(nn.Module):
|
||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
|
||||
x = self.norm(x)
|
||||
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
|
||||
return x
|
||||
|
||||
class ErnieImageModel(nn.Module):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user