This commit is contained in:
Jukka Seppänen 2026-02-01 21:50:58 -05:00 committed by GitHub
commit 80b77cc4e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -237,7 +237,7 @@ class BasicAVTransformerBlock(nn.Module):
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
vx += attn1_out * vgate_msa
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
@ -251,7 +251,7 @@ class BasicAVTransformerBlock(nn.Module):
del norm_ax
# audio cross-attention
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
ax += attn1_out * agate_msa
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
@ -263,9 +263,9 @@ class BasicAVTransformerBlock(nn.Module):
# audio to video cross attention
if run_a2v:
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(0, 2))
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(0, 2))
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
@ -275,15 +275,15 @@ class BasicAVTransformerBlock(nn.Module):
del vx_scaled, ax_scaled
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
vx.addcmul_(a2v_out, gate_out_a2v)
vx += a2v_out * gate_out_a2v
del gate_out_a2v, a2v_out
# video to audio cross attention
if run_v2a:
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(2, 4))
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(2, 4))
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
@ -293,7 +293,7 @@ class BasicAVTransformerBlock(nn.Module):
del ax_scaled, vx_scaled
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
ax.addcmul_(v2a_out, gate_out_v2a)
ax += v2a_out * gate_out_v2a
del gate_out_v2a, v2a_out
del vx_norm3, ax_norm3
@ -308,7 +308,7 @@ class BasicAVTransformerBlock(nn.Module):
del vx_scaled
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
vx.addcmul_(ff_out, vgate_mlp)
vx += ff_out * vgate_mlp
del vgate_mlp, ff_out
# audio feedforward
@ -321,7 +321,7 @@ class BasicAVTransformerBlock(nn.Module):
del ax_scaled
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
ax.addcmul_(ff_out, agate_mlp)
ax += ff_out * agate_mlp
del agate_mlp, ff_out
return vx, ax