kijai 2026-03-22 14:17:05 +02:00
parent 3ecf4d3671
commit 8d914af435


@@ -701,14 +701,14 @@ class LTXAVModel(LTXVModel):
                 time_offset = ref_end[-1].item() + tpl
                 ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
                 ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
-                ref_pos = torch.stack([ref_start, ref_end], dim=-1) if p.start_end else ref_start
+                ref_pos = torch.stack([ref_start, ref_end], dim=-1)
+                additional_args["ref_audio_seq_len"] = ref_audio_seq_len
+                additional_args["target_audio_seq_len"] = ax.shape[1]
                 ax = torch.cat([ref_tokens, ax], dim=1)
                 a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
             ax = self.audio_patchify_proj(ax)
-            additional_args["ref_audio_seq_len"] = ref_audio_seq_len
-            additional_args["total_audio_seq_len"] = ax.shape[1]
             # additional_args.update({"av_orig_shape": list(x.shape)})
         return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
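
The bookkeeping change above can be exercised in isolation. A minimal sketch (the standalone setup and tensor shapes are illustrative assumptions; only the dict keys and the concat mirror the diff): the target audio length is now recorded before the reference tokens are prepended, so downstream code reads it directly instead of subtracting the reference length from a post-concat total.

    import torch

    # Illustrative shapes: batch 2, 16 reference tokens, 48 target tokens, width 64.
    B, ref_len, target_len, dim = 2, 16, 48, 64
    ref_tokens = torch.randn(B, ref_len, dim)
    ax = torch.randn(B, target_len, dim)

    additional_args = {}
    # New behaviour: lengths are captured before the reference tokens are prepended.
    additional_args["ref_audio_seq_len"] = ref_tokens.shape[1]    # 16
    additional_args["target_audio_seq_len"] = ax.shape[1]         # 48, measured pre-concat
    ax = torch.cat([ref_tokens, ax], dim=1)                       # (B, 64, dim)

    # Old behaviour (removed): the total length was measured after the concat,
    # so the target length later had to be recovered as total - ref.
    # additional_args["total_audio_seq_len"] = ax.shape[1]        # would be 64 here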
@@ -750,13 +750,11 @@ class LTXAVModel(LTXVModel):
         a_timestep = kwargs.get("a_timestep")
         ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
         if ref_audio_seq_len > 0 and a_timestep is not None:
-            # Reference tokens must have timestep=0 (clean conditioning, as during training).
-            # Expand scalar/1D timestep to per-token so ref=0 and target=sigma.
-            target_len = kwargs.get("total_audio_seq_len", 0) - ref_audio_seq_len
+            # Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
+            target_len = kwargs.get("target_audio_seq_len")
             if a_timestep.dim() <= 1:
                 a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
-            ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:],
-                                 device=a_timestep.device, dtype=a_timestep.dtype)
+            ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
             a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
         if a_timestep is not None:
             a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
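
The reworked timestep expansion can likewise be sketched standalone, assuming a scalar per-sample sigma and the lengths passed via the keys above (all concrete values are illustrative): the sigma is broadcast across the target tokens only, and zeros are prepended for the reference tokens so they act as clean conditioning.

    import torch

    batch_size, ref_audio_seq_len, target_audio_seq_len = 2, 16, 48
    a_timestep = torch.tensor([0.7, 0.7])   # one sigma per batch element (1D)

    # Broadcast sigma over the target tokens, then prepend zeros for the
    # reference tokens so ref=0 and target=sigma per token.
    if a_timestep.dim() <= 1:
        a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_audio_seq_len)
    ref_ts = torch.zeros(batch_size, ref_audio_seq_len,
                         device=a_timestep.device, dtype=a_timestep.dtype)
    a_timestep = torch.cat([ref_ts, a_timestep], dim=1)

    assert a_timestep.shape == (batch_size, ref_audio_seq_len + target_audio_seq_len)
    assert (a_timestep[:, :ref_audio_seq_len] == 0).all()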