mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
cleanup
This commit is contained in:
parent
3ecf4d3671
commit
8d914af435
@ -701,14 +701,14 @@ class LTXAVModel(LTXVModel):
|
|||||||
time_offset = ref_end[-1].item() + tpl
|
time_offset = ref_end[-1].item() + tpl
|
||||||
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
ref_pos = torch.stack([ref_start, ref_end], dim=-1) if p.start_end else ref_start
|
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
|
||||||
|
|
||||||
|
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||||
|
additional_args["target_audio_seq_len"] = ax.shape[1]
|
||||||
ax = torch.cat([ref_tokens, ax], dim=1)
|
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||||
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||||
|
|
||||||
ax = self.audio_patchify_proj(ax)
|
ax = self.audio_patchify_proj(ax)
|
||||||
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
|
||||||
additional_args["total_audio_seq_len"] = ax.shape[1]
|
|
||||||
|
|
||||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||||
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||||
@ -750,13 +750,11 @@ class LTXAVModel(LTXVModel):
|
|||||||
a_timestep = kwargs.get("a_timestep")
|
a_timestep = kwargs.get("a_timestep")
|
||||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||||
if ref_audio_seq_len > 0 and a_timestep is not None:
|
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||||
# Reference tokens must have timestep=0 (clean conditioning, as during training).
|
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||||
# Expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
target_len = kwargs.get("target_audio_seq_len")
|
||||||
target_len = kwargs.get("total_audio_seq_len", 0) - ref_audio_seq_len
|
|
||||||
if a_timestep.dim() <= 1:
|
if a_timestep.dim() <= 1:
|
||||||
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||||
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:],
|
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
|
||||||
device=a_timestep.device, dtype=a_timestep.dtype)
|
|
||||||
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||||
if a_timestep is not None:
|
if a_timestep is not None:
|
||||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user