mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
Initial ID-LoRA support for LTX2
This commit is contained in:
parent
56ff88f951
commit
f02190aaa8
@ -681,7 +681,34 @@ class LTXAVModel(LTXVModel):
|
|||||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||||
|
|
||||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||||
|
|
||||||
|
# Inject reference audio for ID-LoRA in-context conditioning
|
||||||
|
ref_audio = kwargs.get("ref_audio", None)
|
||||||
|
ref_audio_seq_len = 0
|
||||||
|
if ref_audio is not None:
|
||||||
|
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
|
||||||
|
if ref_tokens.shape[0] < ax.shape[0]:
|
||||||
|
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
|
||||||
|
ref_audio_seq_len = ref_tokens.shape[1]
|
||||||
|
B = ax.shape[0]
|
||||||
|
|
||||||
|
# Compute negative temporal positions matching ID-LoRA convention:
|
||||||
|
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
|
||||||
|
p = self.a_patchifier
|
||||||
|
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
|
||||||
|
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
|
||||||
|
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
|
||||||
|
time_offset = ref_end[-1].item() + tpl
|
||||||
|
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
|
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
|
ref_pos = torch.stack([ref_start, ref_end], dim=-1) if p.start_end else ref_start
|
||||||
|
|
||||||
|
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||||
|
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||||
|
|
||||||
ax = self.audio_patchify_proj(ax)
|
ax = self.audio_patchify_proj(ax)
|
||||||
|
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||||
|
additional_args["total_audio_seq_len"] = ax.shape[1]
|
||||||
|
|
||||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||||
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||||
@ -721,6 +748,16 @@ class LTXAVModel(LTXVModel):
|
|||||||
|
|
||||||
# Prepare audio timestep
|
# Prepare audio timestep
|
||||||
a_timestep = kwargs.get("a_timestep")
|
a_timestep = kwargs.get("a_timestep")
|
||||||
|
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||||
|
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||||
|
# Reference tokens must have timestep=0 (clean conditioning, as during training).
|
||||||
|
# Expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||||
|
target_len = kwargs.get("total_audio_seq_len", 0) - ref_audio_seq_len
|
||||||
|
if a_timestep.dim() <= 1:
|
||||||
|
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||||
|
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:],
|
||||||
|
device=a_timestep.device, dtype=a_timestep.dtype)
|
||||||
|
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||||
if a_timestep is not None:
|
if a_timestep is not None:
|
||||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||||
a_timestep_flat = a_timestep_scaled.flatten()
|
a_timestep_flat = a_timestep_scaled.flatten()
|
||||||
@ -955,6 +992,13 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_embedded_timestep = embedded_timestep[0]
|
v_embedded_timestep = embedded_timestep[0]
|
||||||
a_embedded_timestep = embedded_timestep[1]
|
a_embedded_timestep = embedded_timestep[1]
|
||||||
|
|
||||||
|
# Trim reference audio tokens before unpatchification
|
||||||
|
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||||
|
if ref_audio_seq_len > 0:
|
||||||
|
ax = ax[:, ref_audio_seq_len:]
|
||||||
|
if a_embedded_timestep.shape[1] > 1:
|
||||||
|
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
|
||||||
|
|
||||||
# Expand compressed video timestep if needed
|
# Expand compressed video timestep if needed
|
||||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||||
v_embedded_timestep = v_embedded_timestep.expand()
|
v_embedded_timestep = v_embedded_timestep.expand()
|
||||||
|
|||||||
@ -1053,6 +1053,10 @@ class LTXAV(BaseModel):
|
|||||||
if guide_attention_entries is not None:
|
if guide_attention_entries is not None:
|
||||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
|
ref_audio = kwargs.get("ref_audio", None)
|
||||||
|
if ref_audio is not None:
|
||||||
|
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import node_helpers
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
|||||||
return io.NodeOutput(video_latent, audio_latent)
|
return io.NodeOutput(video_latent, audio_latent)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVReferenceAudio(io.ComfyNode):
    """Node that attaches an encoded reference audio clip to conditioning.

    The encoded clip is stored under the ``"ref_audio"`` key of both the
    positive and negative conditioning. When ``identity_guidance_scale`` is
    non-zero, the returned model clone also carries a post-CFG hook that runs
    one extra prediction with ``ref_audio`` removed and pushes the result away
    from that no-reference prediction.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node id, display metadata, inputs and outputs."""
        return io.Schema(
            node_id="LTXVReferenceAudio",
            display_name="LTXV Reference Audio (ID-LoRA)",
            category="conditioning/audio",
            description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
            inputs=[
                io.Model.Input("model"),
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
                io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
                io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
                io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
            ],
            outputs=[
                io.Model.Output(),
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
        )

    @classmethod
    def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
        """Encode the reference clip, tag the conditioning, and patch the model.

        Args:
            model: Model to clone; the input object itself is not modified here.
            positive: Positive conditioning; receives a ``"ref_audio"`` value.
            negative: Negative conditioning; receives the same ``"ref_audio"`` value.
            reference_audio: Audio passed directly to ``audio_vae.encode``.
            audio_vae: Object whose ``encode`` returns a 4-D latent tensor.
            identity_guidance_scale: Multiplier for the identity guidance term.
                ``0`` disables guidance and no hook is registered.
            start_percent: Converted with ``percent_to_sigma`` to the upper
                sigma bound of the active range.
            end_percent: Converted with ``percent_to_sigma`` to the lower
                sigma bound of the active range.

        Returns:
            ``io.NodeOutput`` holding the cloned model and both updated
            conditionings.
        """
        # Encode reference audio to latents and patchify.
        # NOTE(review): the unpacked names suggest a
        # (batch, channels, time, freq) layout -- confirm against the audio VAE.
        audio_latents = audio_vae.encode(reference_audio)
        b, c, t, f = audio_latents.shape
        # Swap axes 1 and 2, then merge the last two axes: one token of size
        # c * f per position along the original axis 2.
        ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
        ref_audio = {"tokens": ref_tokens}

        positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
        negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})

        # Patch model with identity guidance
        m = model.clone()
        scale = identity_guidance_scale

        # The scale is fixed for the lifetime of the patched model, so decide
        # once here instead of on every sampling step. With guidance disabled
        # the hook would only return args["denoised"] unchanged, so skip
        # registering it altogether.
        if scale == 0:
            return io.NodeOutput(m, positive, negative)

        model_sampling = m.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)

        def post_cfg_function(args):
            """Add ``(cond_pred - pred_noref) * scale`` to the CFG result."""
            sigma = args["sigma"]
            # Only the first element of sigma is inspected for the range check.
            sigma_ = sigma[0].item()
            # Guidance is applied only while sigma_end <= sigma <= sigma_start.
            if sigma_ > sigma_start or sigma_ < sigma_end:
                return args["denoised"]

            cond_pred = args["cond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            model_options = args["model_options"].copy()
            x = args["input"]

            # Strip ref_audio from conditioning for the no-reference pass.
            # Copies are shallow so the sampler's own entries keep ref_audio.
            noref_cond = []
            for entry in cond:
                new_entry = entry.copy()
                mc = new_entry.get("model_conds", {}).copy()
                mc.pop("ref_audio", None)
                new_entry["model_conds"] = mc
                noref_cond.append(new_entry)

            (pred_noref,) = comfy.samplers.calc_cond_batch(
                args["model"], [noref_cond], x, sigma, model_options
            )

            return cfg_result + (cond_pred - pred_noref) * scale

        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return io.NodeOutput(m, positive, negative)
|
||||||
|
|
||||||
|
|
||||||
class LtxvExtension(ComfyExtension):
|
class LtxvExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
|
|||||||
LTXVCropGuides,
|
LTXVCropGuides,
|
||||||
LTXVConcatAVLatent,
|
LTXVConcatAVLatent,
|
||||||
LTXVSeparateAVLatent,
|
LTXVSeparateAVLatent,
|
||||||
|
LTXVReferenceAudio,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user