diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 2c6954ecd..d5b7ea0aa 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -237,7 +237,7 @@ class BasicAVTransformerBlock(nn.Module): del norm_vx # video cross-attention vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] - vx.addcmul_(attn1_out, vgate_msa) + vx += attn1_out * vgate_msa del vgate_msa, attn1_out vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) @@ -251,7 +251,7 @@ class BasicAVTransformerBlock(nn.Module): del norm_ax # audio cross-attention agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] - ax.addcmul_(attn1_out, agate_msa) + ax += attn1_out * agate_msa del agate_msa, attn1_out ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) @@ -263,9 +263,9 @@ class BasicAVTransformerBlock(nn.Module): # audio to video cross attention if run_a2v: scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values( - self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2] + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(0, 2)) scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values( - self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2] + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(0, 2)) vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v @@ -275,15 +275,15 @@ class BasicAVTransformerBlock(nn.Module): del vx_scaled, ax_scaled gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0] - vx.addcmul_(a2v_out, gate_out_a2v) + vx += a2v_out * gate_out_a2v del gate_out_a2v, a2v_out # video to audio cross attention if run_v2a: scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values( - self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4] + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(2, 4)) scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values( - self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4] + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(2, 4)) ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a @@ -293,7 +293,7 @@ class BasicAVTransformerBlock(nn.Module): del ax_scaled, vx_scaled gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0] - ax.addcmul_(v2a_out, gate_out_v2a) + ax += v2a_out * gate_out_v2a del gate_out_v2a, v2a_out del vx_norm3, ax_norm3 @@ -308,7 +308,7 @@ class BasicAVTransformerBlock(nn.Module): del vx_scaled vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0] - vx.addcmul_(ff_out, vgate_mlp) + vx += ff_out * vgate_mlp del vgate_mlp, ff_out # audio feedforward @@ -321,7 +321,7 @@ class BasicAVTransformerBlock(nn.Module): del ax_scaled agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0] - ax.addcmul_(ff_out, agate_mlp) + ax += ff_out * agate_mlp del agate_mlp, ff_out return vx, ax