Merge branch 'master' into feat/core/expected_outputs
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
Alexander Piskun 2026-02-06 09:01:47 +02:00 committed by GitHub
commit 3902c86e83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 115 additions and 40 deletions

View File

@ -7,6 +7,67 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
from comfy.ldm.flux.layers import timestep_embedding
def get_silence_latent(length, device):
head = torch.tensor([[[ 0.5707, 0.0982, 0.6909, -0.5658, 0.6266, 0.6996, -0.1365, -0.1291,
-0.0776, -0.1171, -0.2743, -0.8422, -0.1168, 1.5539, -4.6936, 0.7436,
-1.1846, -0.2637, 0.6933, -6.7266, 0.0966, -0.1187, -0.3501, -1.1736,
0.0587, -2.0517, -1.3651, 0.7508, -0.2490, -1.3548, -0.1290, -0.7261,
1.1132, -0.3249, 0.2337, 0.3004, 0.6605, -0.0298, -0.1989, -0.4041,
0.2843, -1.0963, -0.5519, 0.2639, -1.0436, -0.1183, 0.0640, 0.4460,
-1.1001, -0.6172, -1.3241, 1.1379, 0.5623, -0.1507, -0.1963, -0.4742,
-2.4697, 0.5302, 0.5381, 0.4636, -0.1782, -0.0687, 1.0333, 0.4202],
[ 0.3040, -0.1367, 0.6200, 0.0665, -0.0642, 0.4655, -0.1187, -0.0440,
0.2941, -0.2753, 0.0173, -0.2421, -0.0147, 1.5603, -2.7025, 0.7907,
-0.9736, -0.0682, 0.1294, -5.0707, -0.2167, 0.3302, -0.1513, -0.8100,
-0.3894, -0.2884, -0.3149, 0.8660, -0.3817, -1.7061, 0.5824, -0.4840,
0.6938, 0.1859, 0.1753, 0.3081, 0.0195, 0.1403, -0.0754, -0.2091,
0.1251, -0.1578, -0.4968, -0.1052, -0.4554, -0.0320, 0.1284, 0.4974,
-1.1889, -0.0344, -0.8313, 0.2953, 0.5445, -0.6249, -0.1595, -0.0682,
-3.1412, 0.0484, 0.4153, 0.8260, -0.1526, -0.0625, 0.5366, 0.8473],
[ 5.3524e-02, -1.7534e-01, 5.4443e-01, -4.3501e-01, -2.1317e-03,
3.7200e-01, -4.0143e-03, -1.5516e-01, -1.2968e-01, -1.5375e-01,
-7.7107e-02, -2.0593e-01, -3.2780e-01, 1.5142e+00, -2.6101e+00,
5.8698e-01, -1.2716e+00, -2.4773e-01, -2.7933e-02, -5.0799e+00,
1.1601e-01, 4.0987e-01, -2.2030e-02, -6.6495e-01, -2.0995e-01,
-6.3474e-01, -1.5893e-01, 8.2745e-01, -2.2992e-01, -1.6816e+00,
5.4440e-01, -4.9579e-01, 5.5128e-01, 3.0477e-01, 8.3052e-02,
-6.1782e-02, 5.9036e-03, 2.9553e-01, -8.0645e-02, -1.0060e-01,
1.9144e-01, -3.8124e-01, -7.2949e-01, 2.4520e-02, -5.0814e-01,
2.3977e-01, 9.2943e-02, 3.9256e-01, -1.1993e+00, -3.2752e-01,
-7.2707e-01, 2.9476e-01, 4.3542e-01, -8.8597e-01, -4.1686e-01,
-8.5390e-02, -2.9018e+00, 6.4988e-02, 5.3945e-01, 9.1988e-01,
5.8762e-02, -7.0098e-02, 6.4772e-01, 8.9118e-01],
[-3.2225e-02, -1.3195e-01, 5.6411e-01, -5.4766e-01, -5.2170e-03,
3.1425e-01, -5.4367e-02, -1.9419e-01, -1.3059e-01, -1.3660e-01,
-9.0984e-02, -1.9540e-01, -2.5590e-01, 1.5440e+00, -2.6349e+00,
6.8273e-01, -1.2532e+00, -1.9810e-01, -2.2793e-02, -5.0506e+00,
1.8818e-01, 5.0109e-01, 7.3546e-03, -6.8771e-01, -3.0676e-01,
-7.3257e-01, -1.6687e-01, 9.2232e-01, -1.8987e-01, -1.7267e+00,
5.3355e-01, -5.3179e-01, 4.4953e-01, 2.8820e-01, 1.3012e-01,
-2.0943e-01, -1.1348e-01, 3.3929e-01, -1.5069e-01, -1.2919e-01,
1.8929e-01, -3.6166e-01, -8.0756e-01, 6.6387e-02, -5.8867e-01,
1.6978e-01, 1.0134e-01, 3.3877e-01, -1.2133e+00, -3.2492e-01,
-8.1237e-01, 3.8101e-01, 4.3765e-01, -8.0596e-01, -4.4531e-01,
-4.7513e-02, -2.9266e+00, 1.1741e-03, 4.5123e-01, 9.3075e-01,
5.3688e-02, -1.9621e-01, 6.4530e-01, 9.3870e-01]]], device=device).movedim(-1, 1)
silence_latent = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, length)
silence_latent[:, :, :head.shape[-1]] = head
return silence_latent
def get_layer_class(operations, layer_name):
if operations is not None and hasattr(operations, layer_name):
return getattr(operations, layer_name)
@ -677,7 +738,7 @@ class AttentionPooler(nn.Module):
def forward(self, x):
B, T, P, D = x.shape
x = self.embed_tokens(x)
special = self.special_token.expand(B, T, 1, -1)
special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1)
x = torch.cat([special, x], dim=2)
x = x.view(B * T, P + 1, D)
@ -728,7 +789,7 @@ class FSQ(nn.Module):
self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)
def bound(self, z):
levels_minus_1 = (self._levels - 1).to(z.dtype)
levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1)
scale = 2. / levels_minus_1
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
@ -743,8 +804,8 @@ class FSQ(nn.Module):
return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.
def codes_to_indices(self, zhat):
zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1))
return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32)
def forward(self, z):
orig_dtype = z.dtype
@ -826,7 +887,7 @@ class ResidualFSQ(nn.Module):
x = self.project_in(x)
if hasattr(self, 'soft_clamp_input_value'):
sc_val = self.soft_clamp_input_value.to(x.dtype)
sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype)
x = (x / sc_val).tanh() * sc_val
quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
@ -834,7 +895,7 @@ class ResidualFSQ(nn.Module):
all_indices = []
for layer, scale in zip(self.layers, self.scales):
scale = scale.to(residual.dtype)
scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype)
quantized, indices = layer(residual / scale)
quantized = quantized * scale
@ -1040,22 +1101,21 @@ class AceStepConditionGenerationModel(nn.Module):
lm_hints = self.detokenizer(lm_hints_5Hz)
lm_hints = lm_hints[:, :src_latents.shape[1], :]
if is_covers is None:
if is_covers is None or is_covers is True:
src_latents = lm_hints
else:
src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)
elif is_covers is False:
src_latents = refer_audio_acoustic_hidden_states_packed
context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)
return encoder_hidden, encoder_mask, context_latents
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
text_attention_mask = None
lyric_attention_mask = None
refer_audio_order_mask = None
attention_mask = None
chunk_masks = None
is_covers = None
src_latents = None
precomputed_lm_hints_25Hz = None
lyric_hidden_states = lyric_embed
@ -1067,7 +1127,7 @@ class AceStepConditionGenerationModel(nn.Module):
if refer_audio_order_mask is None:
refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)
if src_latents is None and is_covers is None:
if src_latents is None:
src_latents = x
if chunk_masks is None:

View File

@ -147,11 +147,11 @@ class BaseModel(torch.nn.Module):
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
@ -1560,22 +1560,11 @@ class ACEStep15(BaseModel):
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
if refer_audio is None or len(refer_audio) == 0:
refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
pass_audio_codes = True
else:
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
out['is_covers'] = comfy.conds.CONDConstant(True)
pass_audio_codes = False
if pass_audio_codes:
@ -1583,6 +1572,8 @@ class ACEStep15(BaseModel):
if audio_codes is not None:
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
refer_audio = refer_audio[:, :, :750]
else:
out['is_covers'] = comfy.conds.CONDConstant(False)
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
return out

View File

@ -9,6 +9,14 @@ if TYPE_CHECKING:
from uuid import UUID
def _extract_tensor(data, output_channels):
"""Extract tensor from data, handling both single tensors and lists."""
if isinstance(data, list):
# LTX2 AV tensors: [video, audio]
return data[0][:, :output_channels], data[1][:, :output_channels]
return data[:, :output_channels], None
def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args
transformer_options: dict[str] = args[-1]
@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if not transformer_options:
transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"]
x: torch.Tensor = args[0][:, :easycache.output_channels]
x, ax = _extract_tensor(args[0], easycache.output_channels)
sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if easycache.skip_current_step and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
if easycache.initial_step:
easycache.first_cond_uuid = uuids[0]
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
return easycache.apply_cache_diff(x, uuids)
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
else:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
full_output: torch.Tensor = executor(*args, **kwargs)
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
if has_first_cond_uuid and easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev, uuids)
if audio_output is not None:
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
if has_first_cond_uuid:
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return output
return full_output
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
# prepare next x_prev
x: torch.Tensor = args[0][:, :easycache.output_channels]
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
@ -197,6 +216,7 @@ class EasyCacheHolder:
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
@ -245,20 +265,21 @@ class EasyCacheHolder:
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
if self.first_cond_uuid in uuids and not is_audio:
self.total_steps_skipped += 1
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# slice out only what is relevant to this cond
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_this_dim = True
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
if skip_this_dim:
skip_this_dim = False
continue
@ -270,10 +291,11 @@ class EasyCacheHolder:
else:
slicing.append(slice(None))
batch_slice = batch_slice + slicing
x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if output.shape[1:] != x.shape[1:]:
if not self.allow_mismatch:
@ -293,7 +315,7 @@ class EasyCacheHolder:
diff = output - x
batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids
@ -324,6 +346,8 @@ class EasyCacheHolder:
self.output_prev_norm = None
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
del self.uuid_cache_diffs_audio
self.uuid_cache_diffs_audio = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self