From 6555dc65b82c5f072dcad87f0dbccb4fc5f85e6b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Feb 2026 13:43:45 -0800 Subject: [PATCH 1/4] Make ace step 1.5 work without the llm. (#12311) --- comfy/ldm/ace/ace_step15.py | 72 +++++++++++++++++++++++++++++++++---- comfy/model_base.py | 17 +++------ 2 files changed, 70 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index f2b130bc1..7fc7f1e8e 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -7,6 +7,67 @@ from comfy.ldm.modules.attention import optimized_attention import comfy.model_management from comfy.ldm.flux.layers import timestep_embedding +def get_silence_latent(length, device): + head = torch.tensor([[[ 0.5707, 0.0982, 0.6909, -0.5658, 0.6266, 0.6996, -0.1365, -0.1291, + -0.0776, -0.1171, -0.2743, -0.8422, -0.1168, 1.5539, -4.6936, 0.7436, + -1.1846, -0.2637, 0.6933, -6.7266, 0.0966, -0.1187, -0.3501, -1.1736, + 0.0587, -2.0517, -1.3651, 0.7508, -0.2490, -1.3548, -0.1290, -0.7261, + 1.1132, -0.3249, 0.2337, 0.3004, 0.6605, -0.0298, -0.1989, -0.4041, + 0.2843, -1.0963, -0.5519, 0.2639, -1.0436, -0.1183, 0.0640, 0.4460, + -1.1001, -0.6172, -1.3241, 1.1379, 0.5623, -0.1507, -0.1963, -0.4742, + -2.4697, 0.5302, 0.5381, 0.4636, -0.1782, -0.0687, 1.0333, 0.4202], + [ 0.3040, -0.1367, 0.6200, 0.0665, -0.0642, 0.4655, -0.1187, -0.0440, + 0.2941, -0.2753, 0.0173, -0.2421, -0.0147, 1.5603, -2.7025, 0.7907, + -0.9736, -0.0682, 0.1294, -5.0707, -0.2167, 0.3302, -0.1513, -0.8100, + -0.3894, -0.2884, -0.3149, 0.8660, -0.3817, -1.7061, 0.5824, -0.4840, + 0.6938, 0.1859, 0.1753, 0.3081, 0.0195, 0.1403, -0.0754, -0.2091, + 0.1251, -0.1578, -0.4968, -0.1052, -0.4554, -0.0320, 0.1284, 0.4974, + -1.1889, -0.0344, -0.8313, 0.2953, 0.5445, -0.6249, -0.1595, -0.0682, + -3.1412, 0.0484, 0.4153, 0.8260, -0.1526, -0.0625, 0.5366, 0.8473], + [ 5.3524e-02, -1.7534e-01, 5.4443e-01, -4.3501e-01, -2.1317e-03, + 3.7200e-01, -4.0143e-03, -1.5516e-01, -1.2968e-01, -1.5375e-01, + -7.7107e-02, -2.0593e-01, -3.2780e-01, 1.5142e+00, -2.6101e+00, + 5.8698e-01, -1.2716e+00, -2.4773e-01, -2.7933e-02, -5.0799e+00, + 1.1601e-01, 4.0987e-01, -2.2030e-02, -6.6495e-01, -2.0995e-01, + -6.3474e-01, -1.5893e-01, 8.2745e-01, -2.2992e-01, -1.6816e+00, + 5.4440e-01, -4.9579e-01, 5.5128e-01, 3.0477e-01, 8.3052e-02, + -6.1782e-02, 5.9036e-03, 2.9553e-01, -8.0645e-02, -1.0060e-01, + 1.9144e-01, -3.8124e-01, -7.2949e-01, 2.4520e-02, -5.0814e-01, + 2.3977e-01, 9.2943e-02, 3.9256e-01, -1.1993e+00, -3.2752e-01, + -7.2707e-01, 2.9476e-01, 4.3542e-01, -8.8597e-01, -4.1686e-01, + -8.5390e-02, -2.9018e+00, 6.4988e-02, 5.3945e-01, 9.1988e-01, + 5.8762e-02, -7.0098e-02, 6.4772e-01, 8.9118e-01], + [-3.2225e-02, -1.3195e-01, 5.6411e-01, -5.4766e-01, -5.2170e-03, + 3.1425e-01, -5.4367e-02, -1.9419e-01, -1.3059e-01, -1.3660e-01, + -9.0984e-02, -1.9540e-01, -2.5590e-01, 1.5440e+00, -2.6349e+00, + 6.8273e-01, -1.2532e+00, -1.9810e-01, -2.2793e-02, -5.0506e+00, + 1.8818e-01, 5.0109e-01, 7.3546e-03, -6.8771e-01, -3.0676e-01, + -7.3257e-01, -1.6687e-01, 9.2232e-01, -1.8987e-01, -1.7267e+00, + 5.3355e-01, -5.3179e-01, 4.4953e-01, 2.8820e-01, 1.3012e-01, + -2.0943e-01, -1.1348e-01, 3.3929e-01, -1.5069e-01, -1.2919e-01, + 1.8929e-01, -3.6166e-01, -8.0756e-01, 6.6387e-02, -5.8867e-01, + 1.6978e-01, 1.0134e-01, 3.3877e-01, -1.2133e+00, -3.2492e-01, + -8.1237e-01, 3.8101e-01, 4.3765e-01, -8.0596e-01, -4.4531e-01, + -4.7513e-02, -2.9266e+00, 1.1741e-03, 4.5123e-01, 
9.3075e-01, + 5.3688e-02, -1.9621e-01, 6.4530e-01, 9.3870e-01]]], device=device).movedim(-1, 1) + + silence_latent = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, + 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, + -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00, + 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, + 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, + -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, + 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, + -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, + 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, + 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, + -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, + -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, length) + silence_latent[:, :, :head.shape[-1]] = head + return silence_latent + + def get_layer_class(operations, layer_name): if operations is not None and hasattr(operations, layer_name): return getattr(operations, layer_name) @@ -1040,22 +1101,21 @@ class AceStepConditionGenerationModel(nn.Module): lm_hints = self.detokenizer(lm_hints_5Hz) lm_hints = lm_hints[:, :src_latents.shape[1], :] - if is_covers is None: + if is_covers is None or is_covers is True: src_latents = lm_hints - else: - src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents) + elif is_covers is False: + src_latents = refer_audio_acoustic_hidden_states_packed context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1) return encoder_hidden, encoder_mask, context_latents - def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs): + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs): text_attention_mask = None lyric_attention_mask = None refer_audio_order_mask = None attention_mask = None chunk_masks = None - is_covers = None src_latents = None precomputed_lm_hints_25Hz = None lyric_hidden_states = lyric_embed @@ -1067,7 +1127,7 @@ class AceStepConditionGenerationModel(nn.Module): if refer_audio_order_mask is None: refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long) - if src_latents is None and is_covers is None: + if src_latents is None: src_latents = x if chunk_masks is None: diff --git a/comfy/model_base.py b/comfy/model_base.py index a2a34f191..dcbf12074 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1560,22 +1560,11 @@ class ACEStep15(BaseModel): refer_audio = kwargs.get("reference_audio_timbre_latents", None) if refer_audio is None or len(refer_audio) == 0: - refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, - 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, - -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00, - 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, - 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, - -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, - 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, - -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, - 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, - 1.4844e-01, 9.4727e-02, 3.8477e-01, 
-1.2578e+00, -3.3203e-01, - -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, - -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, - 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2]) + refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device) pass_audio_codes = True else: refer_audio = refer_audio[-1][:, :, :noise.shape[2]] + out['is_covers'] = comfy.conds.CONDConstant(True) pass_audio_codes = False if pass_audio_codes: @@ -1583,6 +1572,8 @@ class ACEStep15(BaseModel): if audio_codes is not None: out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) refer_audio = refer_audio[:, :, :750] + else: + out['is_covers'] = comfy.conds.CONDConstant(False) out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) return out From 458292fef0077470f5675ba52555e7bb4c28102e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:15:04 -0800 Subject: [PATCH 2/4] Fix some lowvram stuff with ace step 1.5 (#12312) --- comfy/ldm/ace/ace_step15.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index 7fc7f1e8e..69338336d 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -738,7 +738,7 @@ class AttentionPooler(nn.Module): def forward(self, x): B, T, P, D = x.shape x = self.embed_tokens(x) - special = self.special_token.expand(B, T, 1, -1) + special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1) x = torch.cat([special, x], dim=2) x = x.view(B * T, P + 1, D) @@ -789,7 +789,7 @@ class FSQ(nn.Module): self.register_buffer('implicit_codebook', implicit_codebook, persistent=False) def bound(self, z): - levels_minus_1 = (self._levels - 1).to(z.dtype) + levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1) scale = 2. / levels_minus_1 bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5 @@ -804,8 +804,8 @@ class FSQ(nn.Module): return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1. def codes_to_indices(self, zhat): - zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1)) - return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32) + zhat_normalized = (zhat + 1.) / (2. 
/ (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1)) + return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32) def forward(self, z): orig_dtype = z.dtype @@ -887,7 +887,7 @@ class ResidualFSQ(nn.Module): x = self.project_in(x) if hasattr(self, 'soft_clamp_input_value'): - sc_val = self.soft_clamp_input_value.to(x.dtype) + sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype) x = (x / sc_val).tanh() * sc_val quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype) @@ -895,7 +895,7 @@ class ResidualFSQ(nn.Module): all_indices = [] for layer, scale in zip(self.layers, self.scales): - scale = scale.to(residual.dtype) + scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype) quantized, indices = layer(residual / scale) quantized = quantized * scale From c2d7f07dbf312ef9034c65102f1a45c4a3355c1a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:24:09 -0800 Subject: [PATCH 3/4] Fix issue when using disable_unet_model_creation (#12315) --- comfy/model_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index dcbf12074..3bb54f59e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -147,11 +147,11 @@ class BaseModel(torch.nn.Module): self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) - comfy.model_management.archive_model_dtypes(self.diffusion_model) - self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 From a1c101f861681ff18df5bdb0605e63c1ba9e8a96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 6 Feb 2026 07:43:09 +0200 Subject: [PATCH 4/4] EasyCache: Support LTX2 (#12231) --- comfy_extras/nodes_easycache.py | 50 ++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 90d730df6..51d1e5b9c 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -9,6 +9,14 @@ if TYPE_CHECKING: from uuid import UUID +def _extract_tensor(data, output_channels): + """Extract tensor from data, handling both single tensors and lists.""" + if isinstance(data, list): + # LTX2 AV tensors: [video, audio] + return data[0][:, :output_channels], data[1][:, :output_channels] + return data[:, :output_channels], None + + def easycache_forward_wrapper(executor, *args, **kwargs): # get values from args transformer_options: dict[str] = args[-1] @@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if not transformer_options: transformer_options = args[-2] easycache: EasyCacheHolder = transformer_options["easycache"] - x: torch.Tensor = args[0][:, :easycache.output_channels] + x, ax = _extract_tensor(args[0], easycache.output_channels) sigmas = transformer_options["sigmas"] uuids = transformer_options["uuids"] if sigmas is not None and 
easycache.is_past_end_timestep(sigmas):
@@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
     if easycache.skip_current_step and can_apply_cache_diff:
         if easycache.verbose:
             logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
-        return easycache.apply_cache_diff(x, uuids)
+        result = easycache.apply_cache_diff(x, uuids)
+        if ax is not None:
+            result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
+            return [result, result_audio]
+        return result
     if easycache.initial_step:
         easycache.first_cond_uuid = uuids[0]
     has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
@@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
                 logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
             # other conds should also skip this step, and instead use their cached values
             easycache.skip_current_step = True
-            return easycache.apply_cache_diff(x, uuids)
+            result = easycache.apply_cache_diff(x, uuids)
+            if ax is not None:
+                result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
+                return [result, result_audio]
+            return result
         else:
             if easycache.verbose:
                 logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
             easycache.cumulative_change_rate = 0.0
 
-    output: torch.Tensor = executor(*args, **kwargs)
+    full_output: torch.Tensor = executor(*args, **kwargs)
+    output, audio_output = _extract_tensor(full_output, easycache.output_channels)
     if has_first_cond_uuid and easycache.has_output_prev_norm():
         output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
         if easycache.verbose:
@@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
             logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
     # TODO: allow cache_diff to be offloaded
     easycache.update_cache_diff(output, next_x_prev, uuids)
+    if audio_output is not None:
+        easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
     if has_first_cond_uuid:
         easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
         easycache.output_prev_subsampled = easycache.subsample(output, uuids)
         easycache.output_prev_norm = output.flatten().abs().mean()
         if easycache.verbose:
             logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
-    return output
+    return full_output
 
 def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     # get values from args
@@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
     if easycache.is_past_end_timestep(timestep):
         return executor(*args, **kwargs)
+    x, _ = _extract_tensor(args[0], easycache.output_channels)  # LazyCache only tracks the primary (video) latent
     # prepare next x_prev
-    x: torch.Tensor = args[0][:, :easycache.output_channels]
     next_x_prev = x
     input_change = None
     do_easycache = easycache.should_do_easycache(timestep)
@@ -197,6 +216,7 @@ class EasyCacheHolder:
         self.output_prev_subsampled: torch.Tensor = None
         self.output_prev_norm: torch.Tensor = None
         self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
+        self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
         self.output_change_rates = []
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
@@ -245,20 +265,21 @@ 
class EasyCacheHolder:
     def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
         return all(uuid in self.uuid_cache_diffs for uuid in uuids)
 
-    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
-        if self.first_cond_uuid in uuids:
+    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        if self.first_cond_uuid in uuids and not is_audio:
             self.total_steps_skipped += 1
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         batch_offset = x.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
             # slice out only what is relevant to this cond
             batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
             # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
-            if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
+            if x.shape[1:] != cache_diffs[uuid].shape[1:]:
                 if not self.allow_mismatch:
-                    raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
+                    raise ValueError(f"Cached dims {cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
                 slicing = []
                 skip_this_dim = True
-                for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
+                for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
                     if skip_this_dim:
                         skip_this_dim = False
                         continue
@@ -270,10 +291,11 @@ class EasyCacheHolder:
                 else:
                     slicing.append(slice(None))
                 batch_slice = batch_slice + slicing
-            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
+            x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
         return x
 
-    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
+    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
         if output.shape[1:] != x.shape[1:]:
             if not self.allow_mismatch:
@@ -293,7 +315,7 @@ class EasyCacheHolder:
         diff = output - x
         batch_offset = diff.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
-            self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
+            cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
 
     def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
         return self.first_cond_uuid in uuids
@@ -324,6 +346,8 @@ class EasyCacheHolder:
         self.output_prev_norm = None
         del self.uuid_cache_diffs
         self.uuid_cache_diffs = {}
+        del self.uuid_cache_diffs_audio
+        self.uuid_cache_diffs_audio = {}
        self.total_steps_skipped = 0
         self.state_metadata = None
         return self
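
For reference, the get_silence_latent() helper added in the first patch reduces to a simple pattern: a short "head" holding distinct channel values for the first few latent frames, followed by a single steady-state frame repeated out to the requested length. Below is a minimal sketch of that shape logic with random stand-ins for the hardcoded constants; make_padded_latent is a hypothetical name, and it assumes length is at least the number of head frames:

import torch

def make_padded_latent(head: torch.Tensor, steady: torch.Tensor, length: int) -> torch.Tensor:
    # head: [1, frames, channels], steady: [1, 1, channels], channels-last as in the patch
    out = steady.movedim(-1, 1).repeat(1, 1, length)  # -> [1, channels, length]
    head = head.movedim(-1, 1)                        # -> [1, channels, frames]
    out[:, :, :head.shape[-1]] = head                 # overwrite the opening frames
    return out

latent = make_padded_latent(torch.randn(1, 4, 64), torch.randn(1, 1, 64), length=750)
print(latent.shape)  # torch.Size([1, 64, 750])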
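
The lowvram fixes in the second patch all follow one rule: replace bare .to(dtype) calls on registered buffers and parameters with comfy.model_management.cast_to(..., device=..., dtype=...), because under weight offloading a buffer can still live on the offload device while the activations are already on the GPU, so casting the dtype alone leaves a device mismatch. A minimal sketch of the idea, with a simplified stand-in for the real helper (the actual function also accepts extra options such as non_blocking):

import torch

def cast_to(t: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
    # simplified stand-in for comfy.model_management.cast_to
    return t.to(device=device, dtype=dtype)

buf = torch.arange(4, dtype=torch.float32)  # buffer left on the offload (CPU) device
x = torch.randn(4, dtype=torch.float16)     # activation; imagine device="cuda"
y = x * cast_to(buf, device=x.device, dtype=x.dtype)  # buf.to(x.dtype) alone would keep buf on the CPU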
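
The EasyCache changes in the fourth patch generalize a residual-caching trick to LTX2's paired latents: on a full step the wrapper stores diff = output - input per cond, and on a skipped step it approximates the model output as input + cached diff; the patch simply keeps two such caches, one for video and one for audio. A minimal self-contained sketch of that scheme (ToyCache and step are hypothetical names, not ComfyUI API):

import torch

class ToyCache:
    def __init__(self):
        self.diffs = {}  # per-stream residuals, like uuid_cache_diffs / uuid_cache_diffs_audio

    def update(self, key, output, x):
        self.diffs[key] = output - x

    def apply(self, key, x):
        return x + self.diffs[key]

def step(model, x, ax, cache, skip):
    if skip and "video" in cache.diffs:
        # replay the cached residuals for both streams (cf. apply_cache_diff(..., is_audio=True))
        return cache.apply("video", x), cache.apply("audio", ax)
    out_v, out_a = model(x, ax)  # the real wrapper calls the wrapped diffusion model here
    cache.update("video", out_v, x)
    cache.update("audio", out_a, ax)
    return out_v, out_a

torch.manual_seed(0)
wv, wa = torch.randn(4, 4), torch.randn(4, 4)
model = lambda x, ax: (x @ wv, ax @ wa)  # toy stand-in for one denoising step
cache = ToyCache()
x, ax = torch.randn(2, 4), torch.randn(2, 4)
full = step(model, x, ax, cache, skip=False)   # computes and caches the diffs
approx = step(model, x, ax, cache, skip=True)  # reuses the diffs instead of running the model
assert torch.allclose(full[0], approx[0]) and torch.allclose(full[1], approx[1])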