Mirror of https://github.com/comfyanonymous/ComfyUI.git
LTX2: Regression fix for changed outputs
commit daea2bce13
parent 09725967cf
@@ -237,7 +237,7 @@ class BasicAVTransformerBlock(nn.Module):
         del norm_vx
         # video cross-attention
         vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
-        vx.addcmul_(attn1_out, vgate_msa)
+        vx += attn1_out * vgate_msa
         del vgate_msa, attn1_out
         vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
 
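The pattern repeated throughout this commit swaps the fused in-place vx.addcmul_(attn1_out, vgate_msa) for the plain vx += attn1_out * vgate_msa. Below is a minimal sketch of why the two forms can produce slightly different tensors; it assumes the "changed outputs" came from how the intermediate product is rounded in low-precision dtypes, which is a plausible reading of the commit title rather than something the diff itself states.

import torch

# Hedged sketch: both statements compute x + a * g in place on x. The addcmul_
# form may evaluate the multiply-add in one fused kernel, while the += form
# materializes a * g in the tensor dtype first, so low-precision dtypes such
# as bfloat16 can round the two paths slightly differently.
torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.bfloat16)
a = torch.randn(4, 8, dtype=torch.bfloat16)
g = torch.randn(1, 8, dtype=torch.bfloat16)   # gate broadcasts, like vgate_msa

fused = x.clone().addcmul_(a, g)   # fused multiply-add path
plain = x.clone()
plain += a * g                     # product rounded first, then added in place

# Any difference is tiny, but in a sampler it can be enough to change outputs.
print((fused - plain).abs().max())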
@@ -251,7 +251,7 @@ class BasicAVTransformerBlock(nn.Module):
         del norm_ax
         # audio cross-attention
         agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
-        ax.addcmul_(attn1_out, agate_msa)
+        ax += attn1_out * agate_msa
         del agate_msa, attn1_out
         ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
 
@@ -263,9 +263,9 @@ class BasicAVTransformerBlock(nn.Module):
         # audio to video cross attention
         if run_a2v:
             scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
-                self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
+                self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(0, 2))
             scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
-                self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
+                self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(0, 2))
 
             vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
             ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
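The other change in this commit is the call shape of get_ada_values: instead of computing values for all four table rows and slicing the returned result with [:2] or [2:4], the needed rows are requested up front via slice(0, 2) / slice(2, 4). The stand-in below is not ComfyUI's get_ada_values; it only illustrates, under that assumption, how a slice argument selects table rows before the per-batch values are built, and why both call shapes should yield the same values.

import torch

# Purely illustrative stand-in for a scale/shift helper: pick rows of a table,
# combine them with a timestep embedding, and return one tensor per row.
def ada_values_sketch(table, batch_size, t_emb, row_slice=None):
    if row_slice is not None:
        table = table[row_slice]                      # select rows up front
    vals = table.unsqueeze(0) + t_emb                 # (1, rows, dim)
    return vals.expand(batch_size, -1, -1).unbind(1)  # one (batch, dim) per row

table = torch.randn(4, 16)   # hypothetical 4-row scale/shift table
t_emb = torch.randn(1, 16)   # hypothetical timestep embedding

# Old call shape in the diff: compute every row, then keep the first two.
scale_a, shift_a = ada_values_sketch(table, 2, t_emb)[:2]
# New call shape in the diff: ask the helper for rows 0..1 directly.
scale_b, shift_b = ada_values_sketch(table, 2, t_emb, slice(0, 2))

# With this toy helper both call shapes give identical tensors; in the real
# model the slice argument also lets the helper skip work on unused rows.
print(torch.equal(scale_a, scale_b), torch.equal(shift_a, shift_b))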
@@ -275,15 +275,15 @@ class BasicAVTransformerBlock(nn.Module):
             del vx_scaled, ax_scaled
 
             gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
-            vx.addcmul_(a2v_out, gate_out_a2v)
+            vx += a2v_out * gate_out_a2v
             del gate_out_a2v, a2v_out
 
         # video to audio cross attention
         if run_v2a:
             scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
-                self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
+                self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(2, 4))
             scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
-                self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
+                self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(2, 4))
 
             ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
             vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
@@ -293,7 +293,7 @@ class BasicAVTransformerBlock(nn.Module):
             del ax_scaled, vx_scaled
 
             gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
-            ax.addcmul_(v2a_out, gate_out_v2a)
+            ax += v2a_out * gate_out_v2a
             del gate_out_v2a, v2a_out
 
         del vx_norm3, ax_norm3
@@ -308,7 +308,7 @@ class BasicAVTransformerBlock(nn.Module):
         del vx_scaled
 
         vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
-        vx.addcmul_(ff_out, vgate_mlp)
+        vx += ff_out * vgate_mlp
         del vgate_mlp, ff_out
 
         # audio feedforward
@@ -321,7 +321,7 @@ class BasicAVTransformerBlock(nn.Module):
         del ax_scaled
 
         agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
-        ax.addcmul_(ff_out, agate_mlp)
+        ax += ff_out * agate_mlp
         del agate_mlp, ff_out
 
         return vx, ax