mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 23:42:36 +08:00
Some optimizations to make Ernie inference a bit faster. (#13472)
This commit is contained in:
parent
b9dedea57d
commit
3d816db07f
@ -118,8 +118,6 @@ class ErnieImageAttention(nn.Module):
|
|||||||
query = apply_rotary_emb(query, image_rotary_emb)
|
query = apply_rotary_emb(query, image_rotary_emb)
|
||||||
key = apply_rotary_emb(key, image_rotary_emb)
|
key = apply_rotary_emb(key, image_rotary_emb)
|
||||||
|
|
||||||
query, key = query.to(x.dtype), key.to(x.dtype)
|
|
||||||
|
|
||||||
q_flat = query.reshape(B, S, -1)
|
q_flat = query.reshape(B, S, -1)
|
||||||
k_flat = key.reshape(B, S, -1)
|
k_flat = key.reshape(B, S, -1)
|
||||||
|
|
||||||
@ -161,16 +159,16 @@ class ErnieImageSharedAdaLNBlock(nn.Module):
|
|||||||
|
|
||||||
residual = x
|
residual = x
|
||||||
x_norm = self.adaLN_sa_ln(x)
|
x_norm = self.adaLN_sa_ln(x)
|
||||||
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
|
x_norm = x_norm * (1 + scale_msa) + shift_msa
|
||||||
|
|
||||||
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
|
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
|
||||||
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
|
x = residual + gate_msa * attn_out
|
||||||
|
|
||||||
residual = x
|
residual = x
|
||||||
x_norm = self.adaLN_mlp_ln(x)
|
x_norm = self.adaLN_mlp_ln(x)
|
||||||
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
|
x_norm = x_norm * (1 + scale_mlp) + shift_mlp
|
||||||
|
|
||||||
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
|
return residual + gate_mlp * self.mlp(x_norm)
|
||||||
|
|
||||||
class ErnieImageAdaLNContinuous(nn.Module):
|
class ErnieImageAdaLNContinuous(nn.Module):
|
||||||
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||||
@ -183,7 +181,7 @@ class ErnieImageAdaLNContinuous(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||||
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
|
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class ErnieImageModel(nn.Module):
|
class ErnieImageModel(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user