From f719a9d928049f85b07b8ecc2259fba4832d37bb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Mar 2026 14:35:22 -0800 Subject: [PATCH 01/16] Adjust memory usage factor of zeta model. (#12746) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c0d3f387f..07feb31b3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1127,7 +1127,7 @@ class ZImagePixelSpace(ZImage): latent_format = latent_formats.ZImagePixelSpace # Much lower memory than latent-space models (no VAE, small patches). - memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + memory_usage_factor = 0.03 # TODO: figure out the optimal value for this. def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) From b6ddc590ed8dafd50df8aad1e626b78276a690c0 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 3 Mar 2026 19:58:53 -0500 Subject: [PATCH 02/16] CURVE type (#12581) * CURVE type * fix: update typed wrapper unwrap keys to __type__ and __value__ * code improve * code improve --- comfy_api/latest/_io.py | 14 ++++++++++++++ execution.py | 14 ++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 189d7d9bc..050031dc0 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO): return d +@comfytype(io_type="CURVE") +class Curve(ComfyTypeIO): + CurvePoint = tuple[float, float] + Type = list[CurvePoint] + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [(0.0, 0.0), (1.0, 1.0)] + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2226,5 +2239,6 @@ __all__ = [ "PriceBadgeDepends", "PriceBadge", "BoundingBox", + "Curve", "NodeReplace", ] diff --git a/execution.py b/execution.py index 75b021892..7ccdbf93e 100644 --- a/execution.py +++ b/execution.py @@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue else: try: - # Unwraps values wrapped in __value__ key. This is used to pass - # list widget value to execution, as by default list value is - # reserved to represent the connection between nodes. - if isinstance(val, dict) and "__value__" in val: - val = val["__value__"] - inputs[x] = val + # Unwraps values wrapped in __value__ key or typed wrapper. + # This is used to pass list widget values to execution, + # as by default list value is reserved to represent the + # connection between nodes. 
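+                    # e.g. a list widget value arriving as {"__value__": [1, 2, 3]}
+                    # unwraps to a plain [1, 2, 3] before the type coercion below.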
+ if isinstance(val, dict): + if "__value__" in val: + val = val["__value__"] + inputs[x] = val if input_type == "INT": val = int(val) From ac6513e142f881202c40eacc5e337982b777ccd0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:19:40 -0800 Subject: [PATCH 03/16] DynamicVram: Add casting / fix torch Buffer weights (#12749) * respect model dtype in non-comfy caster * utils: factor out parent and name functionality of set_attr * utils: implement set_attr_buffer for torch buffers * ModelPatcherDynamic: Implement torch Buffer loading If there is a buffer in dynamic - force load it. --- comfy/model_management.py | 2 ++ comfy/model_patcher.py | 22 ++++++++++++++++++---- comfy/utils.py | 19 +++++++++++++++---- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e0e96672..0f5966371 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -796,6 +796,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e380e406b..70f78a089 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -309,7 +310,7 @@ class ModelPatcher: return comfy.model_management.get_free_memory(device) def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +337,7 @@ class ModelPatcher: n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -1579,11 +1580,22 @@ class ModelPatcherDynamic(ModelPatcher): weight, _, _ = get_key_weight(self.model, key) if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) - comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) - self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + 
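+        # At this point any torch buffers have been force-loaded (cast to the
+        # archived model dtype) and saved in backup_buffers so they can be
+        # restored on unload.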
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") @@ -1607,6 +1619,8 @@ class ModelPatcherDynamic(ModelPatcher): for key in list(self.backup.keys()): bk = self.backup.pop(key) comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) freed += self.model.model_loaded_weight_memory self.model.model_loaded_weight_memory = 0 diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..6e1d14419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") From eb011733b6e4d8a9f7b67a1787d817bfc8c0a5b4 Mon Sep 17 00:00:00 2001 From: Arthur R Longbottom Date: Tue, 3 Mar 2026 21:29:00 -0800 Subject: [PATCH 04/16] Fix VideoFromComponents.save_to crash when writing to BytesIO (#12683) * Fix VideoFromComponents.save_to crash when writing to BytesIO When `get_container_format()` or `get_stream_source()` is called on a tensor-based video (VideoFromComponents), it calls `save_to(BytesIO())`. Since BytesIO has no file extension, `av.open` can't infer the output format and throws `ValueError: Could not determine output format`. The sibling class `VideoFromFile` already handles this correctly via `get_open_write_kwargs()`, which detects BytesIO and sets the format explicitly. `VideoFromComponents` just never got the same treatment. This surfaces when any downstream node validates the container format of a tensor-based video, like TopazVideoEnhance or any node that calls `validate_container_format_is_mp4()`. Three-line fix in `comfy_api/latest/_input_impl/video_types.py`. 
* Add docstring to save_to to satisfy CI coverage check
---
 comfy_api/latest/_input_impl/video_types.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index a3d48c87f..58a37c9e8 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput):
         codec: VideoCodec = VideoCodec.AUTO,
         metadata: Optional[dict] = None,
     ):
+        """Save the video to a file path or BytesIO buffer."""
         if format != VideoContainer.AUTO and format != VideoContainer.MP4:
             raise ValueError("Only MP4 format is supported for now")
         if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@@ -408,6 +409,10 @@
         extra_kwargs = {}
         if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
             extra_kwargs["format"] = format.value
+        elif isinstance(path, io.BytesIO):
+            # BytesIO has no file extension, so av.open can't infer the format.
+            # Default to mp4 since that's the only supported format anyway.
+            extra_kwargs["format"] = "mp4"
         with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
             # Add metadata before writing any streams
             if metadata is not None:

From d531e3fb2a885d675d5b6d3a496b4af5d9757af1 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 4 Mar 2026 07:47:44 -0800
Subject: [PATCH 05/16] model_patcher: Improve dynamic offload heuristic (#12759)

Define a threshold below which loading a weight takes priority.

This actually makes the offload consistent with non-dynamic, because
when non-dynamic fills its to_load list, it will fill up any left-over
space that cannot fit large weights with small weights and load them,
even though they were lower priority.

This also improves performance, because the tiny weights don't cost any
VRAM and aren't worth the control overhead of the DMA etc.
---
 comfy/model_patcher.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 70f78a089..168ce8430 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -699,7 +699,7 @@ class ModelPatcher:
         for key in list(self.pinned):
             self.unpin_weight(key)

-    def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
+    def _load_list(self, for_dynamic=False, default_device=None):
         loading = []
         for n, m in self.model.named_modules():
             default = False
@@ -727,8 +727,13 @@
                     return 0
                 module_offload_mem += check_module_offload_mem("{}.weight".format(n))
                 module_offload_mem += check_module_offload_mem("{}.bias".format(n))
-                prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
-                loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
+                # Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
+                # Non-dynamic: prioritize by module offload cost.
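+                # e.g. modules of 4KB, 1MB and 128KB produce sort keys
+                # (False, -4096), (True, -1048576) and (True, -131072), so an
+                # ascending sort() yields 4KB first, then 1MB, then 128KB.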
+                if for_dynamic:
+                    sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
+                else:
+                    sort_criteria = (module_offload_mem,)
+                loading.append(sort_criteria + (module_mem, n, m, params))
         return loading

     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -1508,11 +1513,11 @@
         if vbar is not None:
             vbar.prioritize()

-        loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
-        loading.sort(reverse=True)
+        loading = self._load_list(for_dynamic=True, default_device=device_to)
+        loading.sort()

         for x in loading:
-            _, _, _, n, m, params = x
+            *_, module_mem, n, m, params = x

             def set_dirty(item, dirty):
                 if dirty or not hasattr(item, "_v_signature"):
@@ -1627,9 +1632,9 @@
         return freed

     def partially_unload_ram(self, ram_to_unload):
-        loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
+        loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
         for x in loading:
-            _, _, _, _, m, _ = x
+            *_, m, _ = x
             ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
             if ram_to_unload <= 0:
                 return

From 9b85cf955858b0aca6b7b30c30b404470ea0c964 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 4 Mar 2026 07:49:13 -0800
Subject: [PATCH 06/16] Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754)

* ops: don't unpin nothing

This was calling into aimdo in the None case (offloaded weight). What's
worse, aimdo syncs when unpinning an offloaded weight, as that is the
corner case of a weight getting evicted by its own use, which does
require a sync. But this was happening for every offloaded weight,
causing a slowdown.

* mp: fix get_free_memory policy

The ModelPatcherDynamic get_free_memory was deducting the model size
from total memory to try to estimate the conceptual free memory without
doing any offloading. This is kind of what the old memory_memory_required
was estimating in the ModelPatcher load logic; however, in practical
reality, between over-estimates and padding, the loader usually
underloaded models enough that sampling could send CFG +/- through
together even when partially loaded.

So don't regress from the status quo and instead go all in on the idea
that offloading is less of an issue than debatching. Tell the sampler
it can use everything.
---
 comfy/model_patcher.py | 14 +++++++-------
 comfy/ops.py           |  4 ++--
 requirements.txt       |  2 +-
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 168ce8430..7e5ad7aa4 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -307,7 +307,13 @@ class ModelPatcher:
         return self.model.lowvram_patch_counter

     def get_free_memory(self, device):
-        return comfy.model_management.get_free_memory(device)
+        #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
+        #the vast majority of setups a little bit of offloading on the giant model more
+        #than pays for CFG. So return everything both torch and Aimdo could give us.
+        aimdo_mem = 0
+        if comfy.memory_management.aimdo_enabled:
+            aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
+        return comfy.model_management.get_free_memory(device) + aimdo_mem

     def get_clone_model_override(self):
         return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
@@ -1465,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher):
         vbar = self._vbar_get()
         return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory

-    def get_free_memory(self, device):
-        #NOTE: on high condition / batch counts, estimate should have already vacated
-        #all non-dynamic models so this is safe even if its not 100% true that this
-        #would all be avaiable for inference use.
-        return comfy.model_management.get_total_memory(device) - self.model_size()
-
     #Pinning is deferred to ops time. Assert against this API to avoid pin leaks.

     def pin_weight_to_device(self, key):
diff --git a/comfy/ops.py b/comfy/ops.py
index 6ee6075fb..8275dd0a5 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
             return
     os, weight_a, bias_a = offload_stream
     device=None
-    #FIXME: This is not good RTTI
-    if not isinstance(weight_a, torch.Tensor):
+    #FIXME: This is really bad RTTI
+    if weight_a is not None and not isinstance(weight_a, torch.Tensor):
        comfy_aimdo.model_vbar.vbar_unpin(s._v)
        device = weight_a
    if os is None:
diff --git a/requirements.txt b/requirements.txt
index 608b0cfa6..110568cd3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@ alembic
 SQLAlchemy
 av>=14.2.0
 comfy-kitchen>=0.2.7
-comfy-aimdo>=0.2.4
+comfy-aimdo>=0.2.5
 requests

 #non essential dependencies:

From 0a7446ade4bbeecfaf36e9a70eeabbeb0f6e59ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com>
Date: Wed, 4 Mar 2026 18:59:56 +0200
Subject: [PATCH 07/16] Pass tokens when loading text gen model for text generation (#12755)

Co-authored-by: Jedrzej Kosinski
---
 comfy/sd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index a9ad7c2d2..8bcd09582 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -428,7 +428,7 @@ class CLIP:

     def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
         self.cond_stage_model.reset_clip_options()
-        self.load_model()
+        self.load_model(tokens)
         self.cond_stage_model.set_clip_options({"layer": None})
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

From 8811db52db5d0aea49c1dbedd733a6b9304b83a9 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 4 Mar 2026 12:12:37 -0800
Subject: [PATCH 08/16] comfy-aimdo 0.2.6 (#12764)

Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would
manifest as an error after a number of workflow runs.
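For an existing source install, picking up the fixed dependency is
typically just `pip install -r requirements.txt --upgrade` (or the
equivalent for your environment).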
--- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 110568cd3..dae46d873 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.5 +comfy-aimdo>=0.2.6 requests #non essential dependencies: From ac4a943ff364885166def5d418582db971554caf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:33:14 -0800 Subject: [PATCH 09/16] Initial load device should be cpu when using dynamic vram. (#12766) --- comfy/model_management.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0f5966371..809600815 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -830,11 +830,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -842,7 +845,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: + if mem_dev > mem_cpu and model_size < mem_dev: return torch_dev else: return cpu_dev @@ -945,6 +948,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device From 43c64b6308f93c331f057e12799bad0a68be5117 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:06:20 -0800 Subject: [PATCH 10/16] Support the LTXAV 2.3 model. 
(#12773) --- comfy/ldm/lightricks/av_model.py | 185 ++++++- comfy/ldm/lightricks/embeddings_connector.py | 4 + comfy/ldm/lightricks/model.py | 186 ++++++- comfy/ldm/lightricks/vae/audio_vae.py | 7 +- .../vae/causal_audio_autoencoder.py | 67 +-- .../vae/causal_video_autoencoder.py | 48 +- comfy/ldm/lightricks/vocoders/vocoder.py | 523 +++++++++++++++++- comfy/model_base.py | 2 +- comfy/sd.py | 2 +- comfy/text_encoders/lt.py | 68 ++- 10 files changed, 959 insertions(+), 133 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 553fd5b38..08d686b7b 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -2,11 +2,16 @@ from typing import Tuple import torch import torch.nn as nn from comfy.ldm.lightricks.model import ( + ADALN_BASE_PARAMS_COUNT, + ADALN_CROSS_ATTN_PARAMS_COUNT, CrossAttention, FeedForward, AdaLayerNormSingle, PixArtAlphaTextProjection, + NormSingleLinearTextProjection, LTXVModel, + apply_cross_attention_adaln, + compute_prompt_timestep, ) from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module): v_context_dim=None, a_context_dim=None, attn_precision=None, + apply_gated_attention=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=v_dim, @@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=vd_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=ad_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module): heads=v_heads, dim_head=vd_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module): a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations ) - self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype)) self.audio_scale_shift_table = nn.Parameter( - torch.empty(6, a_dim, device=device, dtype=dtype) + 
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype) ) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype)) + self.scale_shift_table_a2v_ca_audio = nn.Parameter( torch.empty(5, a_dim, device=device, dtype=dtype) ) @@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) + def _apply_text_cross_attention( + self, x, context, attn, scale_shift_table, prompt_scale_shift_table, + timestep, prompt_timestep, attention_mask, transformer_options, + ): + """Apply text cross-attention, with optional ADaLN modulation.""" + if self.cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values( + scale_shift_table, x.shape[0], timestep, slice(6, 9) + ) + return apply_cross_attention_adaln( + x, context, attn, shift_q, scale_q, gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask, transformer_options, + ) + return attn( + comfy.ldm.common_dit.rms_norm(x), context=context, + mask=attention_mask, transformer_options=transformer_options, + ) + def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, + v_prompt_timestep=None, a_prompt_timestep=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module): vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vx.addcmul_(attn1_out, vgate_msa) del vgate_msa, attn1_out - vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + vx.add_(self._apply_text_cross_attention( + vx, v_context, self.attn2, self.scale_shift_table, + getattr(self, 'prompt_scale_shift_table', None), + v_timestep, v_prompt_timestep, attention_mask, transformer_options,) + ) # audio if run_ax: @@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module): agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] ax.addcmul_(attn1_out, agate_msa) del agate_msa, attn1_out - ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) + ax.add_(self._apply_text_cross_attention( + ax, a_context, self.audio_attn2, self.audio_scale_shift_table, + getattr(self, 'audio_prompt_scale_shift_table', None), + a_timestep, a_prompt_timestep, attention_mask, transformer_options,) + ) # video - audio cross attention. 
if run_a2v or run_v2a: @@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel): use_middle_indices_grid=False, timestep_scale_multiplier=1000.0, av_ca_timestep_scale_multiplier=1.0, + apply_gated_attention=False, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel): self.audio_attention_head_dim = audio_attention_head_dim self.audio_num_attention_heads = audio_num_attention_heads self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.apply_gated_attention = apply_gated_attention # Calculate audio dimensions self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim @@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel): ) # Audio-specific AdaLN + audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.audio_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, + embedding_coefficient=audio_embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations, ) + if self.cross_attention_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=2, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_prompt_adaln_single = None + num_scale_shift_values = 4 self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( self.inner_dim, @@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel): ) # Audio caption projection - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.audio_inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.audio_caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_caption_projection = lambda a: a + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + connector_split_rope = kwargs.get("rope_type", "split") == "split" + connector_gated_attention = kwargs.get("connector_apply_gated_attention", False) + attention_head_dim = kwargs.get("connector_attention_head_dim", 128) + num_attention_heads = kwargs.get("connector_num_attention_heads", 30) + num_layers = kwargs.get("connector_num_layers", 2) self.audio_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim), + num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads), + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) self.video_embeddings_connector = 
Embeddings1DConnector( - split_rope=True, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) - def preprocess_text_embeds(self, context): - if context.shape[-1] == self.caption_channels * 2: - return context - out_vid = self.video_embeddings_connector(context)[0] - out_audio = self.audio_embeddings_connector(context)[0] + def preprocess_text_embeds(self, context, unprocessed=False): + # LTXv2 fully processed context has dimension of self.caption_channels * 2 + # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim + if not unprocessed: + if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2): + return context + if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim: + context_vid = context[:, :, :self.cross_attention_dim] + context_audio = context[:, :, self.cross_attention_dim:] + else: + context_vid = context + context_audio = context + if self.caption_proj_before_connector: + context_vid = self.caption_projection(context_vid) + context_audio = self.audio_caption_projection(context_audio) + out_vid = self.video_embeddings_connector(context_vid)[0] + out_audio = self.audio_embeddings_connector(context_audio)[0] return torch.concat((out_vid, out_audio), dim=-1) def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel): ad_head=self.audio_attention_head_dim, v_context_dim=self.cross_attention_dim, a_context_dim=self.audio_cross_attention_dim, + apply_gated_attention=self.apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel): v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + v_prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: @@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel): # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep_flat, + timestep.max().expand_as(a_timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep_flat, + a_timestep.max().expand_as(timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep_flat * av_ca_factor, + a_timestep.max().expand_as(timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep_flat * av_ca_factor, + timestep.max().expand_as(a_timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -660,29 
+771,40 @@ class LTXAVModel(LTXVModel): # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + + a_prompt_timestep = compute_prompt_timestep( + self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype + ) else: a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] + a_prompt_timestep = None - return [v_timestep, a_timestep, cross_av_timestep_ss], [ + return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [ v_embedded_timestep, a_embedded_timestep, - ] + ], None def _prepare_context(self, context, batch_size, x, attention_mask=None): vx = x[0] ax = x[1] + video_dim = vx.shape[-1] + audio_dim = ax.shape[-1] + + v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim + a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim + v_context, a_context = torch.split( - context, int(context.shape[-1] / 2), len(context.shape) - 1 + context, [v_context_dim, a_context_dim], len(context.shape) - 1 ) v_context, attention_mask = super()._prepare_context( v_context, batch_size, vx, attention_mask ) - if self.audio_caption_projection is not None: + if self.caption_proj_before_connector is False: a_context = self.audio_caption_projection(a_context) - a_context = a_context.view(batch_size, -1, ax.shape[-1]) + a_context = a_context.view(batch_size, -1, audio_dim) return [v_context, a_context], attention_mask @@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel): av_ca_v2a_gate_noise_timestep, ) = timestep[2] + v_prompt_timestep = timestep[3] + a_prompt_timestep = timestep[4] + """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), + v_prompt_timestep=args.get("v_prompt_timestep"), + a_prompt_timestep=args.get("a_prompt_timestep"), ) return out @@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel): "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, + "v_prompt_timestep": v_prompt_timestep, + "a_prompt_timestep": a_prompt_timestep, }, {"original_block": block_wrap}, ) @@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + v_prompt_timestep=v_prompt_timestep, + a_prompt_timestep=a_prompt_timestep, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 33adb9671..2811080be 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module): d_head, context_dim=None, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module): heads=n_heads, dim_head=d_head, context_dim=None, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -121,6 +123,7 @@ class 
Embeddings1DConnector(nn.Module): positional_embedding_max_pos=[4096], causal_temporal_positioning=False, num_learnable_registers: Optional[int] = 128, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module): num_attention_heads, attention_head_dim, context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 60d760d29..bfbc08357 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states +class NormSingleLinearTextProjection(nn.Module): + """Text projection for 20B models - single linear with RMSNorm (no activation).""" + + def __init__( + self, in_features, hidden_size, dtype=None, device=None, operations=None + ): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.in_norm = operations.RMSNorm( + in_features, eps=1e-6, elementwise_affine=False + ) + self.linear_1 = operations.Linear( + in_features, hidden_size, bias=True, dtype=dtype, device=device + ) + self.hidden_size = hidden_size + self.in_features = in_features + + def forward(self, caption): + caption = self.in_norm(caption) + caption = caption * (self.hidden_size / self.in_features) ** 0.5 + return self.linear_1(caption) + + class GELU_approx(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): super().__init__() @@ -343,6 +367,7 @@ class CrossAttention(nn.Module): dim_head=64, dropout=0.0, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -362,6 +387,12 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device) + else: + self.to_gate_logits = None + self.to_out = nn.Sequential( operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -383,16 +414,30 @@ class CrossAttention(nn.Module): out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + out = out.view(b, t, self.heads, self.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity + out = out * gates.unsqueeze(-1) + out = out.view(b, t, self.heads * self.dim_head) + return self.to_out(out) +# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate) +ADALN_BASE_PARAMS_COUNT = 6 +ADALN_CROSS_ATTN_PARAMS_COUNT = 9 class BasicTransformerBlock(nn.Module): def __init__( - self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, 
operations=None ): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, @@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module): operations=operations, ) - self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) - attn1_input = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) - x.addcmul_(attn1_input, gate_msa) - del attn1_input + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2) - x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa + + if self.cross_attention_adaln: + shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2) + x += apply_cross_attention_adaln( + x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca, + self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options, + ) + else: + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) @@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module): return x +def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype): + """Compute a single global prompt timestep for cross-attention ADaLN. + + Uses the max across tokens (matching JAX max_per_segment) and broadcasts + over text tokens. Returns None when *adaln_module* is None. 
+ """ + if adaln_module is None: + return None + ts_input = ( + timestep_scaled.max(dim=1, keepdim=True).values.flatten() + if timestep_scaled.dim() > 1 + else timestep_scaled.flatten() + ) + prompt_ts, _ = adaln_module( + ts_input, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1]) + + +def apply_cross_attention_adaln( + x, context, attn, q_shift, q_scale, q_gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask=None, transformer_options={}, +): + """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV). + + Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so + that both regular tensors and CompressedTimestep are supported. + """ + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate + def get_fractional_positions(indices_grid, max_pos): n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' @@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC): vae_scale_factors: tuple = (8, 32, 32), use_middle_indices_grid=False, timestep_scale_multiplier = 1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + caption_projection_first_linear=True, dtype=None, device=None, operations=None, @@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC): self.causal_temporal_positioning = causal_temporal_positioning self.operations = operations self.timestep_scale_multiplier = timestep_scale_multiplier + self.caption_proj_before_connector = caption_proj_before_connector + self.cross_attention_adaln = cross_attention_adaln + self.caption_projection_first_linear = caption_projection_first_linear # Common dimensions self.inner_dim = num_attention_heads * attention_head_dim @@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC): self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device ) + embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.cross_attention_adaln: + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) + else: + self.prompt_adaln_single = None + + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + 
operations=self.operations, + ) + else: + self.caption_projection = lambda a: a + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) @abstractmethod def _init_model_components(self, device, dtype, **kwargs): @@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC): timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) - return timestep, embedded_timestep + prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + return timestep, embedded_timestep, prompt_timestep def _prepare_context(self, context, batch_size, x, attention_mask=None): """Prepare context for transformer blocks.""" - if self.caption_projection is not None: + if self.caption_proj_before_connector is False: context = self.caption_projection(context) - context = context.view(batch_size, -1, x.shape[-1]) + context = context.view(batch_size, -1, x.shape[-1]) return context, attention_mask def _precompute_freqs_cis( @@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC): merged_args.update(additional_args) # Prepare timestep and context - timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + merged_args["prompt_timestep"] = prompt_timestep context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) # Prepare attention mask and positional embeddings @@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel): causal_temporal_positioning=False, vae_scale_factors=(8, 32, 32), use_middle_indices_grid=False, - timestep_scale_multiplier = 1000.0, + timestep_scale_multiplier=1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel): def _init_model_components(self, device, dtype, **kwargs): """Initialize LTXV-specific components.""" - # No additional components needed for LTXV beyond base class pass def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel): self.num_attention_heads, self.attention_head_dim, context_dim=self.cross_attention_dim, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) 
blocks_replace = patches_replace.get("dit", {}) + prompt_timestep = kwargs.get("prompt_timestep", None) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel): pe=pe, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + prompt_timestep=prompt_timestep, ) return x diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 55a074661..fa0a00748 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( CausalityAxis, CausalAudioAutoencoder, ) -from comfy.ldm.lightricks.vocoders.vocoder import Vocoder +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE LATENT_DOWNSAMPLE_FACTOR = 4 @@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module): vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) - self.vocoder = Vocoder(config=component_config.vocoder) + if "bwe" in component_config.vocoder: + self.vocoder = VocoderWithBWE(config=component_config.vocoder) + else: + self.vocoder = Vocoder(config=component_config.vocoder) self.autoencoder.load_state_dict(vae_sd, strict=False) self.vocoder.load_state_dict(vocoder_sd, strict=False) diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py index f12b9bb53..b556b128f 100644 --- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module): super().__init__() if config is None: - config = self._guess_config() + config = self.get_default_config() - # Extract encoder and decoder configs from the new format model_config = config.get("model", {}).get("params", {}) - variables_config = config.get("variables", {}) - self.sampling_rate = variables_config.get( - "sampling_rate", - model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + self.sampling_rate = model_config.get( + "sampling_rate", config.get("sampling_rate", 16000) ) encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) decoder_config = model_config.get("decoder", 
encoder_config) # Load mel spectrogram parameters self.mel_bins = encoder_config.get("mel_bins", 64) - self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) - self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) # Store causality configuration at VAE level (not just in encoder internals) - causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value) self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) self.is_causal = self.causality_axis == CausalityAxis.HEIGHT @@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module): self.per_channel_statistics = processor() - def _guess_config(self): - encoder_config = { - # Required parameters - based on ltx-video-av-1679000 model metadata - "ch": 128, - "out_ch": 8, - "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] - "num_res_blocks": 2, - "attn_resolutions": [], # Based on metadata: empty list, no attention - "dropout": 0.0, - "resamp_with_conv": True, - "in_channels": 2, # stereo - "resolution": 256, - "z_channels": 8, + def get_default_config(self): + ddconfig = { "double_z": True, - "attn_type": "vanilla", - "mid_block_add_attention": False, # Based on metadata: false + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 2, + "out_ch": 2, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "mid_block_add_attention": False, "norm_type": "pixel", - "causality_axis": "height", # Based on metadata - "mel_bins": 64, # Based on metadata: mel_bins = 64 - } - - decoder_config = { - # Inherits encoder config, can override specific params - **encoder_config, - "out_ch": 2, # Stereo audio output (2 channels) - "give_pre_end": False, - "tanh_out": False, + "causality_axis": "height", } config = { - "_class_name": "CausalAudioAutoencoder", - "sampling_rate": 16000, "model": { "params": { - "encoder": encoder_config, - "decoder": decoder_config, + "ddconfig": ddconfig, + "sampling_rate": 16000, } }, + "preprocessing": { + "stft": { + "filter_length": 1024, + "hop_length": 160, + }, + }, } return config diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index cbfdf412d..5b57dfc5e 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def in_meta_context(): + return torch.device("meta") == torch.empty(0).device + def mark_conv3d_ended(module): tid = threading.get_ident() for _, m in module.named_modules(): @@ -350,6 +353,10 @@ class Decoder(nn.Module): output_channel = output_channel * block_params.get("multiplier", 2) if block_name == "compress_all": output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_space": + output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_time": + output_channel = output_channel * block_params.get("multiplier", 1) self.conv_in = make_conv_nd( dims, @@ -395,17 
+402,21 @@ class Decoder(nn.Module): spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 1, 1), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(1, 2, 2), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": @@ -455,6 +466,15 @@ class Decoder(nn.Module): output_channel * 2, 0, operations=ops, ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + else: + self.register_buffer( + "last_scale_shift_table", + torch.tensor( + [0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(2, output_channel), + persistent=False, + ) # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: @@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module): self.scale_shift_table = nn.Parameter( torch.randn(4, in_channels) / in_channels**0.5 ) + else: + self.register_buffer( + "scale_shift_table", + torch.tensor( + [0.0, 0.0, 0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(4, in_channels), + persistent=False, + ) self.temporal_cache_state={} @@ -1012,9 +1041,6 @@ class processor(nn.Module): super().__init__() self.register_buffer("std-of-means", torch.empty(128)) self.register_buffer("mean-of-means", torch.empty(128)) - self.register_buffer("mean-of-stds", torch.empty(128)) - self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) - self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) @@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module): super().__init__() if config is None: - config = self.guess_config(version) + config = self.get_default_config(version) + self.config = config self.timestep_conditioning = config.get("timestep_conditioning", False) + self.decode_noise_scale = config.get("decode_noise_scale", 0.025) + self.decode_timestep = config.get("decode_timestep", 0.05) double_z = config.get("double_z", True) latent_log_var = config.get( "latent_log_var", "per_channel" if double_z else "none" @@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module): latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + base_channels=config.get("encoder_base_channels", 128), ) self.decoder = Decoder( @@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module): in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + base_channels=config.get("decoder_base_channels", 128), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), @@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module): self.per_channel_statistics = processor() - def guess_config(self, version): + def get_default_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -1167,8 +1198,7 @@ 
class VideoVAE(nn.Module): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x, timestep=0.05, noise_scale=0.025): + def decode(self, x): if self.timestep_conditioning: #TODO: seed - x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) - + x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index b1f15f2c5..6c4028aa8 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import torch.nn as nn import comfy.ops import numpy as np +import math ops = comfy.ops.disable_weight_init @@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor): + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride=1, + padding=True, + padding_mode="replicate", + kernel_size=12, + ): + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"): + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter — identical to torchaudio.functional.resample + # with its default parameters (rolloff=0.99, 
lowpass_filter_width=6). + # Uses replicate boundary padding, matching the reference resampler exactly. + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + t = (torch.arange(self.kernel_size) / ratio - width) * rolloff + t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2 + filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + + self.register_buffer("filter", filter, persistent=persistent) + + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio=2, + down_ratio=2, + up_kernel_size=12, + down_kernel_size=12, + ): + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +# --------------------------------------------------------------------------- +# BigVGAN v2 activations (Snake / SnakeBeta) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def 
forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + b = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + b = torch.exp(b) + return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2) + + +# --------------------------------------------------------------------------- +# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity) +# --------------------------------------------------------------------------- + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"): + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + self.acts1 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))] + ) + self.acts2 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))] + ) + + def forward(self, x): + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# HiFi-GAN residual blocks +# --------------------------------------------------------------------------- + + class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() @@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module): """ Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1"). """ def __init__(self, config=None): @@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module): config = self.get_default_config() resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) - upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) - upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4]) resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_initial_channel = config.get("upsample_initial_channel", 1024) stereo = config.get("stereo", True) - resblock = config.get("resblock", "1") + activation = config.get("activation", "snake") + use_bias_at_final = config.get("use_bias_at_final", True) + + # "output_sample_rate" is not present in recent checkpoint configs. 
+ # When absent (None), AudioVAE.output_sample_rate computes it as: + # sample_rate * vocoder.upsample_factor / mel_hop_length + # where upsample_factor = product of all upsample stride lengths, + # and mel_hop_length is loaded from the autoencoder config at + # preprocessing.stft.hop_length (see CausalAudioAutoencoder). self.output_sample_rate = config.get("output_sample_rate") + self.resblock = config.get("resblock", "1") + self.use_tanh_at_final = config.get("use_tanh_at_final", True) + self.apply_final_activation = config.get("apply_final_activation", True) self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) - resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + if self.resblock == "1": + resblock_cls = ResBlock1 + elif self.resblock == "2": + resblock_cls = ResBlock2 + elif self.resblock == "AMP1": + resblock_cls = AMPBlock1 + else: + raise ValueError(f"Unknown resblock type: {self.resblock}") self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock_class(ch, k, d)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + if self.resblock == "AMP1": + self.resblocks.append(resblock_cls(ch, k, d, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, k, d)) out_channels = 2 if stereo else 1 - self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + if self.resblock == "AMP1": + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(ch)) + else: + self.act_post = nn.LeakyReLU() + + self.conv_post = ops.Conv1d( + ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final + ) self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + def get_default_config(self): """Generate default configuration for the vocoder.""" config = { "resblock_kernel_sizes": [3, 7, 11], - "upsample_rates": [6, 5, 2, 2, 2], - "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_initial_channel": 1024, "stereo": True, "resblock": "1", + "activation": "snake", + "use_bias_at_final": True, + "use_tanh_at_final": True, } return config @@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module): assert x.shape[1] == 2, "Input must have 2 channels for stereo" x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) x = self.conv_pre(x) + for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) + if self.resblock != "AMP1": + x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) xs = None for j in range(self.num_kernels): @@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module): else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels - x = F.leaky_relu(x) + + x = self.act_post(x) x = self.conv_post(x) - x = torch.tanh(x) + + if self.apply_final_activation: + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, -1, 1) return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT × Hann-window bases. 
+ + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int): + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + The STFT is computed by convolving the padded signal with forward_basis. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + Computed in float32 for numerical stability, then cast back to + the input dtype. + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + sampling_rate: int, + mel_fmin: float, + mel_fmax: float, + ): + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram( + self, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + Computed as log(clamp(mel_basis @ magnitude, min=1e-5)). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). 
+ """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(torch.nn.Module): + """Vocoder with bandwidth extension (BWE) for higher sample rate output. + + Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + """ + + def __init__(self, config): + super().__init__() + vocoder_config = config["vocoder"] + bwe_config = config["bwe"] + + self.vocoder = Vocoder(config=vocoder_config) + self.bwe_generator = Vocoder( + config={**bwe_config, "apply_final_activation": False} + ) + + self.input_sample_rate = bwe_config["input_sampling_rate"] + self.output_sample_rate = bwe_config["output_sampling_rate"] + self.hop_length = bwe_config["hop_length"] + + self.mel_stft = MelSTFT( + filter_length=bwe_config["n_fft"], + hop_length=bwe_config["hop_length"], + win_length=bwe_config["n_fft"], + n_mel_channels=bwe_config["num_mels"], + sampling_rate=bwe_config["input_sampling_rate"], + mel_fmin=0.0, + mel_fmax=bwe_config["input_sampling_rate"] / 2.0, + ) + self.resampler = UpSample1d( + ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"], + persistent=False, + window_type="hann", + ) + + def _compute_mel(self, audio): + """Compute log-mel spectrogram from waveform using causal STFT bases.""" + B, C, T = audio.shape + flat = audio.reshape(B * C, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec): + x = self.vocoder(mel_spec) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate + + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :T_out] diff --git a/comfy/model_base.py b/comfy/model_base.py index 1e01e9edc..d9d5a9293 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1021,7 +1021,7 @@ class LTXAV(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: if hasattr(self.diffusion_model, "preprocess_text_embeds"): - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) diff --git a/comfy/sd.py b/comfy/sd.py index 8bcd09582..888ef1e77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage elif clip_type == CLIPType.LTXV: - clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.clip = 
comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e86ea9f4e..5e1273c6e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is +class DualLinearProjection(torch.nn.Module): + def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None): + super().__init__() + self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device) + self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device) + + def forward(self, x): + source_dim = x.shape[-1] + x = x.movedim(1, -1) + x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2) + + video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim)) + audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim)) + return torch.cat((video, audio), dim=-1) + class LTXAVTEModel(torch.nn.Module): - def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}): super().__init__() self.dtypes = set() self.dtypes.add(dtype) self.compat_mode = False + self.text_projection_type = text_projection_type self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) self.dtypes.add(dtype_llama) operations = self.gemma3_12b.operations # TODO - self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + if self.text_projection_type == "single_linear": + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + elif self.text_projection_type == "dual_linear": + self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations) + def enable_compat_mode(self): # TODO: remove from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module): out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) - out = out.movedim(1, -1).to(self.execution_device) - out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) - out = out.reshape((out.shape[0], out.shape[1], -1)) - out = self.text_embedding_projection(out) - out = out.float() - if self.compat_mode: - out_vid = self.video_embeddings_connector(out)[0] - out_audio = self.audio_embeddings_connector(out)[0] - out = torch.concat((out_vid, out_audio), dim=-1) + if self.text_projection_type == "single_linear": + out = out.movedim(1, 
-1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) - return out.to(out_device), pooled + if self.compat_mode: + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + extra = {} + else: + extra = {"unprocessed_ltxav_embeds": True} + elif self.text_projection_type == "dual_linear": + out = self.text_embedding_projection(out) + extra = {"unprocessed_ltxav_embeds": True} + + return out.to(device=out_device, dtype=torch.float), pooled, extra def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) @@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) else: - sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True) if len(sdo) == 0: sdo = sd @@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module): num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 -def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: @@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options) return LTXAVTEModel_ + +def sd_detect(state_dict_list, prefix=""): + for sd in state_dict_list: + if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd: + return {"text_projection_type": "dual_linear"} + if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd: + return {"text_projection_type": "single_linear"} + return {} + + def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): class Gemma3_12BModel_(Gemma3_12BModel): def __init__(self, device="cpu", dtype=None, model_options={}): From f2ee7f2d367f98bb8a33bcb4a224bda441eb8a07 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:21:55 -0800 Subject: [PATCH 11/16] Fix cublas ops on dynamic vram. 
(#12776) --- comfy/ops.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8275dd0a5..3e19cd1b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -660,23 +660,29 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations From c5fe8ace68c432a262a5093bdd84b3ed70b9d283 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 15:37:35 +0800 Subject: [PATCH 12/16] chore: update workflow templates to v0.9.6 (#12778) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dae46d873..5f99407b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.5 +comfyui-workflow-templates==0.9.6 comfyui-embedded-docs==0.4.3 torch torchsde From 4941671b5a5c65fea48be922caa76b7f6a0a4595 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:39:51 -0800 Subject: [PATCH 13/16] Fix cuda getting initialized in cpu mode. 
(#12779) --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 809600815..ee28ea107 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1666,12 +1666,16 @@ def lora_compute_dtype(device): return dtype def synchronize(): + if cpu_mode(): + return if is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() def soft_empty_cache(force=False): + if cpu_mode(): + return global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() From c8428541a6b6e4b1e0fbd685e9c846efcb60179e Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 16:58:25 +0800 Subject: [PATCH 14/16] chore: update workflow templates to v0.9.7 (#12780) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5f99407b7..866818e08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.6 +comfyui-workflow-templates==0.9.7 comfyui-embedded-docs==0.4.3 torch torchsde From e04d0dbeb8266aa9262b5a4c3934ba4e4a371e37 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 04:06:29 -0500 Subject: [PATCH 15/16] ComfyUI v0.16.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 6a35c6de3..0aea18d3a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.1" +__version__ = "0.16.0" diff --git a/pyproject.toml b/pyproject.toml index 1b2318273..f2133d99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.1" +version = "0.16.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From bd21363563ce8e312c9271a0c64a0145335df8a9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:29:39 +0200 Subject: [PATCH 16/16] feat(api-nodes-xAI): updated models, pricing, added features (#12756) --- comfy_api_nodes/apis/grok.py | 14 ++++-- comfy_api_nodes/nodes_grok.py | 92 +++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py index 8e3c79ab9..c56c8aecc 100644 --- a/comfy_api_nodes/apis/grok.py +++ b/comfy_api_nodes/apis/grok.py @@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel): aspect_ratio: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + resolution: str = Field(...) class InputUrlObject(BaseModel): @@ -16,12 +17,13 @@ class InputUrlObject(BaseModel): class ImageEditRequest(BaseModel): model: str = Field(...) - image: InputUrlObject = Field(...) + images: list[InputUrlObject] = Field(...) prompt: str = Field(...) resolution: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + aspect_ratio: str | None = Field(...) 
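
# --- Aside (editor's sketch, not part of the patch): how the reworked
# ImageEditRequest above is meant to serialize. The request now takes a
# list of input images ("images") instead of a single "image", the old
# "response_for" field is renamed to "response_format", and "aspect_ratio"
# is a new nullable field. The snippet mirrors the model definitions in
# this hunk; the concrete values and the elided data URL are assumptions.
#
#     req = ImageEditRequest(
#         model="grok-imagine-image",
#         images=[InputUrlObject(url="data:image/png;base64,...")],
#         prompt="replace the sky with a night sky",
#         resolution="1k",
#         n=1,
#         seed=0,
#         response_format="url",
#         aspect_ratio=None,  # nodes_grok.py passes None when the widget is "auto"
#     )
#     payload = req.model_dump_json()  # pydantic v2
# ---
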
class VideoGenerationRequest(BaseModel): @@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel): revised_prompt: str | None = Field(None) +class UsageObject(BaseModel): + cost_in_usd_ticks: int | None = Field(None) + + class ImageGenerationResponse(BaseModel): data: list[ImageResponseObject] = Field(...) + usage: UsageObject | None = Field(None) class VideoGenerationResponse(BaseModel): @@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel): status: str | None = Field(None) video: VideoResponseObject | None = Field(None) model: str | None = Field(None) + usage: UsageObject | None = Field(None) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index da15e97ea..0716d6239 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -27,6 +27,12 @@ from comfy_api_nodes.util import ( ) +def _extract_grok_price(response) -> float | None: + if response.usage and response.usage.cost_in_usd_ticks is not None: + return response.usage.cost_in_usd_ticks / 10_000_000_000 + return None + + class GrokImageNode(IO.ComfyNode): @classmethod @@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode): category="api node/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), IO.String.Input( "prompt", multiline=True, @@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), ], outputs=[ IO.Image.Output(), @@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 
0.07 : 0.02; + {"type":"usd","usd": $rate * widgets.number_of_images} + ) + """, ), ) @@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio: str, number_of_images: int, seed: int, + resolution: str = "1K", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) response = await sync_op( @@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio=aspect_ratio, n=number_of_images, seed=seed, + resolution=resolution.lower(), ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode): category="api node/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), - IO.Image.Input("image"), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), + IO.Image.Input("image", display_name="images"), IO.String.Input( "prompt", multiline=True, tooltip="The text prompt used to generate the image", ), - IO.Combo.Input("resolution", options=["1K"]), + IO.Combo.Input("resolution", options=["1K", "2K"]), IO.Int.Input( "number_of_images", default=1, @@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + optional=True, + tooltip="Only allowed when multiple images are connected to the image input.", + ), ], outputs=[ IO.Image.Output(), @@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} + ) + """, ), ) @@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode): resolution: str, number_of_images: int, seed: int, + aspect_ratio: str = "auto", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Only one input image is supported.") + if model == "grok-imagine-image-pro": + if get_number_of_images(image) > 1: + raise ValueError("The pro model supports only 1 input image.") + elif get_number_of_images(image) > 3: + raise ValueError("A maximum of 3 input images is supported.") + if aspect_ratio != "auto" and get_number_of_images(image) == 1: + raise ValueError( + "Custom aspect ratio is only allowed when multiple images are connected to the image input." 
+ ) response = await sync_op( cls, ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), data=ImageEditRequest( model=model, - image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], prompt=prompt, resolution=resolution.lower(), n=number_of_images, seed=seed, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode): category="api node/video/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), expr=""" ( - $base := 0.181 * widgets.duration; + $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $base := $rate * widgets.duration; {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} ) """, @@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) @@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode): category="api node/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""", ), ) @@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url))
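
# --- Aside (editor's sketch, not part of the patch): the new
# _extract_grok_price helper converts xAI's usage field to dollars by
# dividing "cost_in_usd_ticks" by 10_000_000_000. Reading that constant as
# "1 tick = 1e-10 USD" is an inference from this patch, not documented API
# behavior. A self-contained equivalent:

def ticks_to_usd(cost_in_usd_ticks: int | None) -> float | None:
    # No usage block in the response: return None, mirroring
    # _extract_grok_price. What the caller does with a None price
    # (e.g. falling back to the static price_badge estimate) is
    # outside this patch.
    if cost_in_usd_ticks is None:
        return None
    return cost_in_usd_ticks / 10_000_000_000

# 700_000_000 ticks -> $0.07, which lines up with the "pro" image rate
# used in the price_badge expressions above.
assert ticks_to_usd(700_000_000) == 0.07
assert ticks_to_usd(None) is None
# ---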