diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8e7912e6d..2020326c2 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -386,7 +386,7 @@ class Flux(nn.Module): h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) ref_num_tokens.append(kontext.shape[1]) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 08d686b7b..6f2ba41ef 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel): additional_args["has_spatial_mask"] = has_spatial_mask ax, a_latent_coords = self.a_patchifier.patchify(ax) + + # Inject reference audio for ID-LoRA in-context conditioning + ref_audio = kwargs.get("ref_audio", None) + ref_audio_seq_len = 0 + if ref_audio is not None: + ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device) + if ref_tokens.shape[0] < ax.shape[0]: + ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1) + ref_audio_seq_len = ref_tokens.shape[1] + B = ax.shape[0] + + # Compute negative temporal positions matching ID-LoRA convention: + # offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0 + p = self.a_patchifier + tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate + ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device) + ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device) + time_offset = ref_end[-1].item() + tpl + ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_pos = torch.stack([ref_start, ref_end], dim=-1) + + additional_args["ref_audio_seq_len"] = ref_audio_seq_len + additional_args["target_audio_seq_len"] = ax.shape[1] + ax = torch.cat([ref_tokens, ax], dim=1) + a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2) + ax = self.audio_patchify_proj(ax) # additional_args.update({"av_orig_shape": list(x.shape)}) @@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel): # Prepare audio timestep a_timestep = kwargs.get("a_timestep") + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0 and a_timestep is not None: + # Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma. + target_len = kwargs.get("target_audio_seq_len") + if a_timestep.dim() <= 1: + a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len) + ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype) + a_timestep = torch.cat([ref_ts, a_timestep], dim=1) if a_timestep is not None: a_timestep_scaled = a_timestep * self.timestep_scale_multiplier a_timestep_flat = a_timestep_scaled.flatten() @@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel): v_embedded_timestep = embedded_timestep[0] a_embedded_timestep = embedded_timestep[1] + # Trim reference audio tokens before unpatchification + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0: + ax = ax[:, ref_audio_seq_len:] + if a_embedded_timestep.shape[1] > 1: + a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:] + # Expand compressed video timestep if needed if isinstance(v_embedded_timestep, CompressedTimestep): v_embedded_timestep = v_embedded_timestep.expand() diff --git a/comfy/model_base.py b/comfy/model_base.py index 43ec93324..70aff886e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -937,9 +937,10 @@ class LongCatImage(Flux): transformer_options = transformer_options.copy() rope_opts = transformer_options.get("rope_options", {}) rope_opts = dict(rope_opts) + pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0 rope_opts.setdefault("shift_t", 1.0) - rope_opts.setdefault("shift_y", 512.0) - rope_opts.setdefault("shift_x", 512.0) + rope_opts.setdefault("shift_y", pe_len) + rope_opts.setdefault("shift_x", pe_len) transformer_options["rope_options"] = rope_opts return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) @@ -1060,6 +1061,10 @@ class LTXAV(BaseModel): if guide_attention_entries is not None: out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + ref_audio = kwargs.get("ref_audio", None) + if ref_audio is not None: + out['ref_audio'] = comfy.conds.CONDConstant(ref_audio) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ccc200b7a..9fdea999c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module): grid = e.get("extra", None) start = e.get("index") if position_ids is None: - position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long) position_ids[:, :start] = torch.arange(0, start, device=embeds.device) end = e.get("size") + start len_max = int(grid.max()) // 2 start_next = len_max + start - position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) + if attention_mask is not None: + # Assign compact sequential positions to attended tokens only, + # skipping over padding so post-padding tokens aren't inflated. + after_mask = attention_mask[0, end:] + text_positions = after_mask.cumsum(0) - 1 + start_next + offset + position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:]) + else: + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) position_ids[0, start:end] = start + offset max_d = int(grid[0][1]) // 2 position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py index 882d80901..0962779e3 100644 --- a/comfy/text_encoders/longcat_image.py +++ b/comfy/text_encoders/longcat_image.py @@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): return [output] +IMAGE_PAD_TOKEN_ID = 151655 + class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): + T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + SUFFIX = "<|im_end|>\n<|im_start|>assistant\n" + def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__( embedding_directory=embedding_directory, @@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): name="qwen25_7b", tokenizer=LongCatImageBaseTokenizer, ) - self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" - self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs): skip_template = False if text.startswith("<|im_start|>"): skip_template = True @@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): text, return_word_ids=return_word_ids, disable_weights=True, **kwargs ) else: + has_images = images is not None and len(images) > 0 + template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX + prefix_ids = base_tok.tokenizer( - self.longcat_template_prefix, add_special_tokens=False + template_prefix, add_special_tokens=False )["input_ids"] suffix_ids = base_tok.tokenizer( - self.longcat_template_suffix, add_special_tokens=False + self.SUFFIX, add_special_tokens=False )["input_ids"] prompt_tokens = base_tok.tokenize_with_weights( @@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): suffix_pairs = [(t, 1.0) for t in suffix_ids] combined = prefix_pairs + prompt_pairs + suffix_pairs + + if has_images: + embed_count = 0 + for i in range(len(combined)): + if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images): + combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1]) + embed_count += 1 + tokens = {"qwen25_7b": [combined]} return tokens diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py index 3b18ce730..98c350a12 100644 --- a/comfy/text_encoders/qwen_vl.py +++ b/comfy/text_encoders/qwen_vl.py @@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module): hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) hidden_states = self.merger(hidden_states) + # Potentially important for spatially precise edits. This is present in the HF implementation. + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] return hidden_states diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index c05571143..d7c2e8744 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -3,6 +3,7 @@ import node_helpers import torch import comfy.model_management import comfy.model_sampling +import comfy.samplers import comfy.utils import math import numpy as np @@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode): return io.NodeOutput(video_latent, audio_latent) +class LTXVReferenceAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVReferenceAudio", + display_name="LTXV Reference Audio (ID-LoRA)", + category="conditioning/audio", + description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."), + io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."), + io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."), + ], + outputs=[ + io.Model.Output(), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: + # Encode reference audio to latents and patchify + audio_latents = audio_vae.encode(reference_audio) + b, c, t, f = audio_latents.shape + ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) + ref_audio = {"tokens": ref_tokens} + + positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio}) + negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio}) + + # Patch model with identity guidance + m = model.clone() + scale = identity_guidance_scale + model_sampling = m.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + + def post_cfg_function(args): + if scale == 0: + return args["denoised"] + + sigma = args["sigma"] + sigma_ = sigma[0].item() + if sigma_ > sigma_start or sigma_ < sigma_end: + return args["denoised"] + + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + model_options = args["model_options"].copy() + x = args["input"] + + # Strip ref_audio from conditioning for the no-reference pass + noref_cond = [] + for entry in cond: + new_entry = entry.copy() + mc = new_entry.get("model_conds", {}).copy() + mc.pop("ref_audio", None) + new_entry["model_conds"] = mc + noref_cond.append(new_entry) + + (pred_noref,) = comfy.samplers.calc_cond_batch( + args["model"], [noref_cond], x, sigma, model_options + ) + + return cfg_result + (cond_pred - pred_noref) * scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return io.NodeOutput(m, positive, negative) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension): LTXVCropGuides, LTXVConcatAVLatent, LTXVSeparateAVLatent, + LTXVReferenceAudio, ] diff --git a/comfyui_version.py b/comfyui_version.py index a3b7204dc..61d7672ca 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.18.0" +__version__ = "0.18.1" diff --git a/main.py b/main.py index f99aee38e..cd4483c67 100644 --- a/main.py +++ b/main.py @@ -471,6 +471,9 @@ if __name__ == "__main__": if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") + if args.disable_dynamic_vram: + logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.") + event_loop, _, start_all_func = start_comfyui() try: x = start_all_func() diff --git a/manager_requirements.txt b/manager_requirements.txt index 5b06b56f6..90a2be84e 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b6 \ No newline at end of file +comfyui_manager==4.1b8 diff --git a/pyproject.toml b/pyproject.toml index 6db9b1267..1fc9402a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.18.0" +version = "0.18.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10"