Small anima optimization. (#14557)

This commit is contained in:
comfyanonymous 2026-06-19 17:05:28 -07:00 committed by GitHub
parent 69d34f2654
commit e00b55631a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -515,7 +515,7 @@ class Block(nn.Module):
h=H, h=H,
w=W, w=W,
) )
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
def _x_fn( def _x_fn(
_x_B_T_H_W_D: torch.Tensor, _x_B_T_H_W_D: torch.Tensor,
@ -548,7 +548,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options, transformer_options=transformer_options,
) )
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
normalized_x_B_T_H_W_D = _fn( normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D, x_B_T_H_W_D,
@ -557,7 +557,7 @@ class Block(nn.Module):
shift_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D,
) )
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
return x_B_T_H_W_D return x_B_T_H_W_D