diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 671fe834d..aec874815 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -515,7 +515,7 @@ class Block(nn.Module): h=H, w=W, ) - x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) def _x_fn( _x_B_T_H_W_D: torch.Tensor, @@ -548,7 +548,7 @@ class Block(nn.Module): shift_cross_attn_B_T_1_1_D, transformer_options=transformer_options, ) - x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) normalized_x_B_T_H_W_D = _fn( x_B_T_H_W_D, @@ -557,7 +557,7 @@ class Block(nn.Module): shift_mlp_B_T_1_1_D, ) result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) - x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) return x_B_T_H_W_D