mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Small anima optimization.
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
This commit is contained in:
parent
bd39bbf067
commit
93e3fd4c47
@ -515,7 +515,7 @@ class Block(nn.Module):
|
|||||||
h=H,
|
h=H,
|
||||||
w=W,
|
w=W,
|
||||||
)
|
)
|
||||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||||
|
|
||||||
def _x_fn(
|
def _x_fn(
|
||||||
_x_B_T_H_W_D: torch.Tensor,
|
_x_B_T_H_W_D: torch.Tensor,
|
||||||
@ -548,7 +548,7 @@ class Block(nn.Module):
|
|||||||
shift_cross_attn_B_T_1_1_D,
|
shift_cross_attn_B_T_1_1_D,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||||
|
|
||||||
normalized_x_B_T_H_W_D = _fn(
|
normalized_x_B_T_H_W_D = _fn(
|
||||||
x_B_T_H_W_D,
|
x_B_T_H_W_D,
|
||||||
@ -557,7 +557,7 @@ class Block(nn.Module):
|
|||||||
shift_mlp_B_T_1_1_D,
|
shift_mlp_B_T_1_1_D,
|
||||||
)
|
)
|
||||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||||
return x_B_T_H_W_D
|
return x_B_T_H_W_D
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user