From 8d914af4354d202ddb0d88e04c5468105d4fd0f7 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 22 Mar 2026 14:17:05 +0200
Subject: [PATCH] cleanup

---
 comfy/ldm/lightricks/av_model.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index a4f288c36..6f2ba41ef 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -701,14 +701,14 @@ class LTXAVModel(LTXVModel):
         time_offset = ref_end[-1].item() + tpl
         ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
         ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
-        ref_pos = torch.stack([ref_start, ref_end], dim=-1) if p.start_end else ref_start
+        ref_pos = torch.stack([ref_start, ref_end], dim=-1)
+        additional_args["ref_audio_seq_len"] = ref_audio_seq_len
+        additional_args["target_audio_seq_len"] = ax.shape[1]
 
         ax = torch.cat([ref_tokens, ax], dim=1)
         a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
 
         ax = self.audio_patchify_proj(ax)
-        additional_args["ref_audio_seq_len"] = ref_audio_seq_len
-        additional_args["total_audio_seq_len"] = ax.shape[1]
 
         # additional_args.update({"av_orig_shape": list(x.shape)})
         return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
@@ -750,13 +750,11 @@ class LTXAVModel(LTXVModel):
         a_timestep = kwargs.get("a_timestep")
         ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
         if ref_audio_seq_len > 0 and a_timestep is not None:
-            # Reference tokens must have timestep=0 (clean conditioning, as during training).
-            # Expand scalar/1D timestep to per-token so ref=0 and target=sigma.
-            target_len = kwargs.get("total_audio_seq_len", 0) - ref_audio_seq_len
+            # Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
+            target_len = kwargs.get("target_audio_seq_len")
             if a_timestep.dim() <= 1:
                 a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
-            ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:],
-                                 device=a_timestep.device, dtype=a_timestep.dtype)
+            ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
             a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
         if a_timestep is not None:
             a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
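Note: the two hunks work together. The first records "target_audio_seq_len" as the length of the target audio tokens alone (measured before the reference tokens are concatenated), so the second can read it directly instead of subtracting "ref_audio_seq_len" from a total. Below is a minimal, self-contained sketch of the per-token timestep expansion the second hunk performs; the batch size, sequence lengths, and sigma values are illustrative placeholders, not values from the model.

    import torch

    batch_size = 2
    ref_audio_seq_len = 4     # prepended reference-audio tokens (clean conditioning)
    target_audio_seq_len = 6  # audio tokens actually being denoised

    # Scalar-per-batch-item timestep as a sampler would pass it (one sigma each).
    a_timestep = torch.tensor([0.7, 0.3])

    # Expand to per-token: every target token carries its batch item's sigma.
    a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_audio_seq_len)

    # Reference tokens get timestep 0, matching clean conditioning at training time.
    ref_ts = torch.zeros(batch_size, ref_audio_seq_len,
                         device=a_timestep.device, dtype=a_timestep.dtype)

    # Concatenate in the same [ref, target] order as the token sequence itself.
    a_timestep = torch.cat([ref_ts, a_timestep], dim=1)  # shape (2, 10)

    # Each row: four zeros for the reference tokens, then that item's sigma
    # repeated over the six target tokens.
    print(a_timestep)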