From 3c1a1a2df82fc6c7aa20a3c8301d1c632e1a1d87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:06:18 -0800 Subject: [PATCH] Basic support for the ace step 1.5 model. (#12237) --- comfy/latent_formats.py | 4 + comfy/ldm/ace/ace_step15.py | 1093 ++++++++++++++++++++++++++++++++++ comfy/model_base.py | 42 ++ comfy/model_detection.py | 5 + comfy/sd.py | 21 +- comfy/sd1_clip.py | 2 + comfy/supported_models.py | 35 +- comfy/text_encoders/ace15.py | 218 +++++++ comfy/text_encoders/llama.py | 67 +++ comfy_extras/nodes_ace.py | 75 +++ comfy_extras/nodes_audio.py | 10 +- nodes.py | 2 +- 12 files changed, 1566 insertions(+), 8 deletions(-) create mode 100644 comfy/ldm/ace/ace_step15.py create mode 100644 comfy/text_encoders/ace15.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4b3a3798c..f59999af6 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -755,6 +755,10 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 +class ACEAudio15(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + class ChromaRadiance(LatentFormat): latent_channels = 3 spacial_downscale_ratio = 1 diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py new file mode 100644 index 000000000..d90549658 --- /dev/null +++ b/comfy/ldm/ace/ace_step15.py @@ -0,0 +1,1093 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import itertools +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management +from comfy.ldm.flux.layers import timestep_embedding + +def get_layer_class(operations, layer_name): + if operations is not None and hasattr(operations, layer_name): + return getattr(operations, layer_name) + return getattr(nn, layer_name) + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None): + super().__init__() + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.get_default_dtype() if dtype is None else dtype) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len, x.device, x.dtype) + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device), + self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device), + ) + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + cos = cos.unsqueeze(0).unsqueeze(0) + sin = sin.unsqueeze(0).unsqueeze(0) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device) + self.act1 = nn.SiLU() + self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device) + self.in_channels = in_channels + self.act2 = nn.SiLU() + self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device) + self.scale = scale + + def forward(self, t, dtype=None): + t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale) + temb = self.linear_1(t_freq.to(dtype=dtype)) + temb = self.act1(temb) + temb = self.linear_2(temb) + timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1]) + return temb, timestep_proj + +class AceStepAttention(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + rms_norm_eps=1e-6, + is_cross_attention=False, + sliding_window=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.is_cross_attention = is_cross_attention + self.sliding_window = sliding_window + + Linear = get_layer_class(operations, "Linear") + + self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device) + self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.q_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + position_embeddings=None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + query_states = query_states.transpose(1, 2) + + if self.is_cross_attention and encoder_hidden_states is not None: + bsz_enc, kv_len, _ = encoder_hidden_states.size() + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + + key_states = key_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + else: + kv_len = q_len + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + n_rep = self.num_heads // self.num_kv_heads + if n_rep > 1: + key_states = key_states.repeat_interleave(n_rep, dim=1) + value_states = value_states.repeat_interleave(n_rep, dim=1) + + attn_bias = None + if self.sliding_window is not None and not self.is_cross_attention: + indices = torch.arange(q_len, device=query_states.device) + diff = indices.unsqueeze(1) - indices.unsqueeze(0) + in_window = torch.abs(diff) <= self.sliding_window + + window_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype) + min_value = torch.finfo(query_states.dtype).min + window_bias.masked_fill_(~in_window, min_value) + + window_bias = window_bias.unsqueeze(0).unsqueeze(0) + + if attn_bias is not None: + if attn_bias.dtype == torch.bool: + base_bias = torch.zeros_like(window_bias) + base_bias.masked_fill_(~attn_bias, min_value) + attn_bias = base_bias + window_bias + else: + attn_bias = attn_bias + window_bias + else: + attn_bias = window_bias + + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True) + attn_output = self.o_proj(attn_output) + + return attn_output + +class AceStepDiTLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + layer_type="full_attention", + sliding_window=128, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self_attn_window = sliding_window if layer_type == "sliding_attention" else None + + self.self_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=False, sliding_window=self_attn_window, + dtype=dtype, device=device, operations=operations + ) + + self.cross_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.cross_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=True, dtype=dtype, device=device, operations=operations + ) + + self.mlp_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + temb, + encoder_hidden_states, + position_embeddings, + attention_mask=None, + encoder_attention_mask=None + ): + modulation = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1) + + norm_hidden = self.self_attn_norm(hidden_states) + norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa + + attn_out = self.self_attn( + norm_hidden, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = hidden_states + attn_out * gate_msa + + norm_hidden = self.cross_attn_norm(hidden_states) + attn_out = self.cross_attn( + norm_hidden, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask + ) + hidden_states = hidden_states + attn_out + + norm_hidden = self.mlp_norm(hidden_states) + norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa + + mlp_out = self.mlp(norm_hidden) + hidden_states = hidden_states + mlp_out * c_gate_msa + + return hidden_states + +class AceStepEncoderLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=False, dtype=dtype, device=device, operations=operations + ) + self.input_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.post_attention_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + def forward(self, hidden_states, position_embeddings, attention_mask=None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + +class AceStepLyricEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, inputs_embeds, attention_mask=None): + hidden_states = self.embed_tokens(inputs_embeds) + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + position_embeddings = (cos, sin) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + + hidden_states = self.norm(hidden_states) + return hidden_states + +class AceStepTimbreEncoder(nn.Module): + def __init__( + self, + timbre_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + + def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask): + N, d = timbre_embs_packed.shape + device = timbre_embs_packed.device + B = N + counts = torch.bincount(refer_audio_order_mask, minlength=B) + max_count = counts.max().item() + + sorted_indices = torch.argsort( + refer_audio_order_mask * N + torch.arange(N, device=device), + stable=True + ) + sorted_batch_ids = refer_audio_order_mask[sorted_indices] + + positions = torch.arange(N, device=device) + batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]]) + positions_in_sorted = positions - batch_starts[sorted_batch_ids] + + inverse_indices = torch.empty_like(sorted_indices) + inverse_indices[sorted_indices] = torch.arange(N, device=device) + positions_in_batch = positions_in_sorted[inverse_indices] + + indices_2d = refer_audio_order_mask * max_count + positions_in_batch + one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype) + + timbre_embs_flat = one_hot.t() @ timbre_embs_packed + timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d) + + mask_flat = (one_hot.sum(dim=0) > 0).long() + new_mask = mask_flat.view(B, max_count) + return timbre_embs_unpack, new_mask + + def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None): + hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed) + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=(cos, sin), + attention_mask=attention_mask + ) + hidden_states = self.norm(hidden_states) + + flat_states = hidden_states[:, 0, :] + unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask) + return unpacked_embs, unpacked_mask + + +def pack_sequences(hidden1, hidden2, mask1, mask2): + hidden_cat = torch.cat([hidden1, hidden2], dim=1) + B, L, D = hidden_cat.shape + + if mask1 is not None and mask2 is not None: + mask_cat = torch.cat([mask1, mask2], dim=1) + sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) + gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D) + hidden_sorted = torch.gather(hidden_cat, 1, gather_idx) + lengths = mask_cat.sum(dim=1) + new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)) + else: + new_mask = None + hidden_sorted = hidden_cat + + return hidden_sorted, new_mask + +class AceStepConditionEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + timbre_hidden_dim, + hidden_size, + num_lyric_layers, + num_timbre_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.lyric_encoder = AceStepLyricEncoder( + text_hidden_dim=text_hidden_dim, + hidden_size=hidden_size, + num_layers=num_lyric_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + self.timbre_encoder = AceStepTimbreEncoder( + timbre_hidden_dim=timbre_hidden_dim, + hidden_size=hidden_size, + num_layers=num_timbre_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + def forward( + self, + text_hidden_states=None, + text_attention_mask=None, + lyric_hidden_states=None, + lyric_attention_mask=None, + refer_audio_acoustic_hidden_states_packed=None, + refer_audio_order_mask=None + ): + text_emb = self.text_projector(text_hidden_states) + + lyric_emb = self.lyric_encoder( + inputs_embeds=lyric_hidden_states, + attention_mask=lyric_attention_mask + ) + + timbre_emb, timbre_mask = self.timbre_encoder( + refer_audio_acoustic_hidden_states_packed, + refer_audio_order_mask + ) + + merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask) + final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask) + + return final_emb, final_mask + +# -------------------------------------------------------------------------------- +# Main Diffusion Model (DiT) +# -------------------------------------------------------------------------------- + +class AceStepDiTModel(nn.Module): + def __init__( + self, + in_channels, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + patch_size, + audio_acoustic_hidden_dim, + layer_types=None, + sliding_window=128, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.patch_size = patch_size + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + Conv1d = get_layer_class(operations, "Conv1d") + ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d") + Linear = get_layer_class(operations, "Linear") + + self.proj_in = nn.Sequential( + nn.Identity(), + Conv1d( + in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, + dtype=dtype, device=device)) + + self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + + if layer_types is None: + layer_types = ["full_attention"] * num_layers + + if len(layer_types) < num_layers: + layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers)) + + self.layers = nn.ModuleList([ + AceStepDiTLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + layer_type=layer_types[i], + sliding_window=sliding_window, + dtype=dtype, device=device, operations=operations + ) for i in range(num_layers) + ]) + + self.norm_out = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.proj_out = nn.Sequential( + nn.Identity(), + ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device) + ) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + timestep, + timestep_r, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + context_latents + ): + temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype) + temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype) + temb = temb_t + temb_r + timestep_proj = proj_t + proj_r + + x = torch.cat([context_latents, hidden_states], dim=-1) + original_seq_len = x.shape[1] + + pad_length = 0 + if x.shape[1] % self.patch_size != 0: + pad_length = self.patch_size - (x.shape[1] % self.patch_size) + x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0) + + x = x.transpose(1, 2) + x = self.proj_in(x) + x = x.transpose(1, 2) + + encoder_hidden_states = self.condition_embedder(encoder_hidden_states) + + seq_len = x.shape[1] + cos, sin = self.rotary_emb(x, seq_len=seq_len) + + for layer in self.layers: + x = layer( + hidden_states=x, + temb=timestep_proj, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=(cos, sin), + attention_mask=None, + encoder_attention_mask=None + ) + + shift, scale = (comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + x = self.norm_out(x) * (1 + scale) + shift + + x = x.transpose(1, 2) + x = self.proj_out(x) + x = x.transpose(1, 2) + + x = x[:, :original_seq_len, :] + return x + + +class AttentionPooler(nn.Module): + def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, x): + B, T, P, D = x.shape + x = self.embed_tokens(x) + special = self.special_token.expand(B, T, 1, -1) + x = torch.cat([special, x], dim=2) + x = x.view(B * T, P + 1, D) + + cos, sin = self.rotary_emb(x, seq_len=P + 1) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + return x[:, 0, :].view(B, T, D) + + +class FSQ(nn.Module): + def __init__( + self, + levels, + dim=None, + device=None, + dtype=None, + operations=None + ): + super().__init__() + + _levels = torch.tensor(levels, dtype=torch.int32, device=device) + self.register_buffer('_levels', _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0) + self.register_buffer('_basis', _basis, persistent=False) + + self.codebook_dim = len(levels) + self.dim = dim if dim is not None else self.codebook_dim + + requires_projection = self.dim != self.codebook_dim + if requires_projection: + self.project_in = operations.Linear(self.dim, self.codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(self.codebook_dim, self.dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.codebook_size = self._levels.prod().item() + + indices = torch.arange(self.codebook_size, device=device) + implicit_codebook = self._indices_to_codes(indices) + + if dtype is not None: + implicit_codebook = implicit_codebook.to(dtype) + + self.register_buffer('implicit_codebook', implicit_codebook, persistent=False) + + def bound(self, z): + levels_minus_1 = (self._levels - 1).to(z.dtype) + scale = 2. / levels_minus_1 + bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5 + + zhat = bracket.floor() + bracket_ste = bracket + (zhat - bracket).detach() + + return scale * bracket_ste - 1. + + def _indices_to_codes(self, indices): + indices = indices.unsqueeze(-1) + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1. + + def codes_to_indices(self, zhat): + zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1)) + return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32) + + def forward(self, z): + orig_dtype = z.dtype + z = self.project_in(z) + + codes = self.bound(z) + indices = self.codes_to_indices(codes) + + out = self.project_out(codes) + return out.to(orig_dtype), indices + + +class ResidualFSQ(nn.Module): + def __init__( + self, + levels, + num_quantizers, + dim=None, + bound_hard_clamp=True, + device=None, + dtype=None, + operations=None, + **kwargs + ): + super().__init__() + + codebook_dim = len(levels) + dim = dim if dim is not None else codebook_dim + + requires_projection = codebook_dim != dim + if requires_projection: + self.project_in = operations.Linear(dim, codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(codebook_dim, dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.layers = nn.ModuleList() + levels_tensor = torch.tensor(levels, device=device) + scales = [] + + for ind in range(num_quantizers): + scale_val = levels_tensor.float() ** -ind + scales.append(scale_val) + + self.layers.append(FSQ( + levels=levels, + dim=codebook_dim, + device=device, + dtype=dtype, + operations=operations + )) + + scales_tensor = torch.stack(scales) + if dtype is not None: + scales_tensor = scales_tensor.to(dtype) + self.register_buffer('scales', scales_tensor, persistent=False) + + if bound_hard_clamp: + val = 1 + (1 / (levels_tensor.float() - 1)) + if dtype is not None: + val = val.to(dtype) + self.register_buffer('soft_clamp_input_value', val, persistent=False) + + def get_output_from_indices(self, indices, dtype=torch.float32): + if indices.dim() == 2: + indices = indices.unsqueeze(-1) + + all_codes = [] + for i, layer in enumerate(self.layers): + idx = indices[..., i].long() + codes = F.embedding(idx, comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype)) + all_codes.append(codes * comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype)) + + codes_summed = torch.stack(all_codes).sum(dim=0) + return self.project_out(codes_summed) + + def forward(self, x): + x = self.project_in(x) + + if hasattr(self, 'soft_clamp_input_value'): + sc_val = self.soft_clamp_input_value.to(x.dtype) + x = (x / sc_val).tanh() * sc_val + + quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype) + residual = x + all_indices = [] + + for layer, scale in zip(self.layers, self.scales): + scale = scale.to(residual.dtype) + + quantized, indices = layer(residual / scale) + quantized = quantized * scale + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + all_indices.append(indices) + + quantized_out = self.project_out(quantized_out) + all_indices = torch.stack(all_indices, dim=-1) + + return quantized_out, all_indices + + +class AceStepAudioTokenizer(nn.Module): + def __init__( + self, + audio_acoustic_hidden_dim, + hidden_size, + pool_window_size, + fsq_dim, + fsq_levels, + fsq_input_num_quantizers, + num_layers, + head_dim, + rms_norm_eps, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device) + self.attention_pooler = AttentionPooler( + hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations + ) + self.pool_window_size = pool_window_size + self.fsq_dim = fsq_dim + self.quantizer = ResidualFSQ( + dim=fsq_dim, + levels=fsq_levels, + num_quantizers=fsq_input_num_quantizers, + bound_hard_clamp=True, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, hidden_states): + hidden_states = self.audio_acoustic_proj(hidden_states) + hidden_states = self.attention_pooler(hidden_states) + quantized, indices = self.quantizer(hidden_states) + return quantized, indices + + def tokenize(self, x): + B, T, D = x.shape + P = self.pool_window_size + + if T % P != 0: + pad = P - (T % P) + x = F.pad(x, (0, 0, 0, pad)) + T = x.shape[1] + + T_patch = T // P + x = x.view(B, T_patch, P, D) + + quantized, indices = self.forward(x) + return quantized, indices + + +class AudioTokenDetokenizer(nn.Module): + def __init__( + self, + hidden_size, + pool_window_size, + audio_acoustic_hidden_dim, + num_layers, + head_dim, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.pool_window_size = pool_window_size + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device)) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) + self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device) + + def forward(self, x): + B, T, D = x.shape + x = self.embed_tokens(x) + x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1) + x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype) + x = x.view(B * T, self.pool_window_size, D) + + cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + x = self.proj_out(x) + return x.view(B, T * self.pool_window_size, -1) + + +class AceStepConditionGenerationModel(nn.Module): + def __init__( + self, + in_channels=192, + hidden_size=2048, + text_hidden_dim=1024, + timbre_hidden_dim=64, + audio_acoustic_hidden_dim=64, + num_dit_layers=24, + num_lyric_layers=8, + num_timbre_layers=4, + num_tokenizer_layers=2, + num_heads=16, + num_kv_heads=8, + head_dim=128, + intermediate_size=6144, + patch_size=2, + pool_window_size=5, + rms_norm_eps=1e-06, + timestep_mu=-0.4, + timestep_sigma=1.0, + data_proportion=0.5, + sliding_window=128, + layer_types=None, + fsq_dim=2048, + fsq_levels=[8, 8, 8, 5, 5, 5], + fsq_input_num_quantizers=1, + audio_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + self.timestep_mu = timestep_mu + self.timestep_sigma = timestep_sigma + self.data_proportion = data_proportion + self.pool_window_size = pool_window_size + + if layer_types is None: + layer_types = [] + for i in range(num_dit_layers): + layer_types.append("sliding_attention" if i % 2 == 0 else "full_attention") + + self.decoder = AceStepDiTModel( + in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim, + intermediate_size, patch_size, audio_acoustic_hidden_dim, + layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.encoder = AceStepConditionEncoder( + text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers, + num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.tokenizer = AceStepAudioTokenizer( + audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.detokenizer = AudioTokenDetokenizer( + hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim, + dtype=dtype, device=device, operations=operations + ) + self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + def prepare_condition( + self, + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, + precomputed_lm_hints_25Hz=None, + audio_codes=None + ): + encoder_hidden, encoder_mask = self.encoder( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask + ) + + if precomputed_lm_hints_25Hz is not None: + lm_hints = precomputed_lm_hints_25Hz + else: + if audio_codes is not None: + if audio_codes.shape[1] * 5 < src_latents.shape[1]: + audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847) + lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype) + else: + assert False + # TODO ? + + lm_hints = self.detokenizer(lm_hints_5Hz) + + lm_hints = lm_hints[:, :src_latents.shape[1], :] + if is_covers is None: + src_latents = lm_hints + else: + src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents) + + context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1) + + return encoder_hidden, encoder_mask, context_latents + + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs): + text_attention_mask = None + lyric_attention_mask = None + refer_audio_order_mask = None + attention_mask = None + chunk_masks = None + is_covers = None + src_latents = None + precomputed_lm_hints_25Hz = None + lyric_hidden_states = lyric_embed + text_hidden_states = context + refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2) + + x = x.movedim(-1, -2) + + if refer_audio_order_mask is None: + refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long) + + if src_latents is None and is_covers is None: + src_latents = x + + if chunk_masks is None: + chunk_masks = torch.ones_like(x) + + enc_hidden, enc_mask, context_latents = self.prepare_condition( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes + ) + + out = self.decoder(hidden_states=x, + timestep=timestep, + timestep_r=timestep, + attention_mask=attention_mask, + encoder_hidden_states=enc_hidden, + encoder_attention_mask=enc_mask, + context_latents=context_latents + ) + + return out.movedim(-1, -2) diff --git a/comfy/model_base.py b/comfy/model_base.py index 85acdb66a..89944548c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model +import comfy.ldm.ace.ace_step15 import comfy.model_management import comfy.patcher_extension @@ -1540,6 +1541,47 @@ class ACEStep(BaseModel): out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out +class ACEStep15(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + device = kwargs["device"] + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_lyrics = kwargs.get("conditioning_lyrics", None) + if cross_attn is not None: + out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics) + + refer_audio = kwargs.get("reference_audio_timbre_latents", None) + if refer_audio is None or len(refer_audio) == 0: + refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, + 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, + -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00, + 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, + 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, + -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, + 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, + -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, + 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, + 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, + -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, + -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750) + else: + refer_audio = refer_audio[-1] + out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) + + audio_codes = kwargs.get("audio_codes", None) + if audio_codes is not None: + out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) + + return out + class Omnigen2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8cea16e50..e8ad725df 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -655,6 +655,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["audio_model"] = "ace1.5" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index fd0ac85e7..722c0c154 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -59,6 +59,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie import comfy.text_encoders.anima +import comfy.text_encoders.ace15 import comfy.model_patcher import comfy.lora @@ -452,6 +453,8 @@ class VAE: self.extra_1d_channel = None self.crop_input = True + self.audio_sample_rate = 44100 + if config is None: if "decoder.mid.block_1.mix_factor" in sd: encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -549,14 +552,25 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: - self.first_stage_model = AudioOobleckVAE() + config = {} + param_key = None + if "decoder.layers.2.layers.1.weight_v" in sd: + param_key = "decoder.layers.2.layers.1.weight_v" + if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd: + param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1" + if param_key is not None: + if sd[param_key].shape[-1] == 12: + config["strides"] = [2, 4, 4, 6, 10] + self.audio_sample_rate = 48000 + + self.first_stage_model = AudioOobleckVAE(**config) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.latent_channels = 64 self.output_channels = 2 self.pad_channel_value = "replicate" self.upscale_ratio = 2048 - self.downscale_ratio = 2048 + self.downscale_ratio = 2048 self.latent_dim = 1 self.process_output = lambda audio: audio self.process_input = lambda audio: audio @@ -1427,6 +1441,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_data_jina = clip_data[0] tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) + elif clip_type == CLIPType.ACE: + clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 9ecfc9c55..4c817d468 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass + elif isinstance(layer_idx, list): + self.layer = layer_idx elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d25271d6e..6b7d831cb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.anima +import comfy.text_encoders.ace15 from . import supported_models_base from . import latent_formats @@ -1596,6 +1597,38 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +class ACEStep15(supported_models_base.BASE): + unet_config = { + "audio_model": "ace1.5", + } + + unet_extra_config = { + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + latent_format = comfy.latent_formats.ACEAudio15 + + memory_usage_factor = 4.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.ACEStep15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py new file mode 100644 index 000000000..9070cb577 --- /dev/null +++ b/comfy/text_encoders/ace15.py @@ -0,0 +1,218 @@ +from .anima import Qwen3Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import torch +import math + + +def sample_manual_loop_no_classes( + model, + ids=None, + paddings=[], + execution_dtype=None, + cfg_scale: float = 2.0, + temperature: float = 0.85, + top_p: float = 0.9, + top_k: int = None, + seed: int = 1, + min_tokens: int = 1, + max_new_tokens: int = 2048, + audio_start_id: int = 151669, # The cutoff ID for audio codes + eos_token_id: int = 151645, +): + device = model.execution_device + + if execution_dtype is None: + if comfy.model_management.should_use_bf16(device): + execution_dtype = torch.bfloat16 + else: + execution_dtype = torch.float32 + + embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) + for i, t in enumerate(paddings): + attention_mask[i, :t] = 0 + attention_mask[i, t:] = 1 + + output_audio_codes = [] + past_key_values = [] + generator = torch.Generator(device=device) + generator.manual_seed(seed) + model_config = model.transformer.model.config + + for x in range(model_config.num_hidden_layers): + past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + + for step in range(max_new_tokens): + outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) + next_token_logits = model.transformer.logits(outputs[0])[:, -1] + past_key_values = outputs[2] + + cond_logits = next_token_logits[0:1] + uncond_logits = next_token_logits[1:2] + cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits) + + if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + eos_score = cfg_logits[:, eos_token_id].clone() + + # Only generate audio tokens + cfg_logits[:, :audio_start_id] = float('-inf') + + if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + cfg_logits[:, eos_token_id] = eos_score + + if top_k is not None and top_k > 0: + top_k_vals, _ = torch.topk(cfg_logits, top_k) + min_val = top_k_vals[..., -1, None] + cfg_logits[cfg_logits < min_val] = float('-inf') + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + cfg_logits[indices_to_remove] = float('-inf') + + if temperature > 0: + cfg_logits = cfg_logits / temperature + next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1) + else: + next_token = torch.argmax(cfg_logits, dim=-1) + + token = next_token.item() + + if token == eos_token_id: + break + + embed, _, _, _ = model.process_tokens([[token]], device) + embeds = embed.repeat(2, 1, 1) + attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1) + + output_audio_codes.append(token - audio_start_id) + + return output_audio_codes + + +def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0): + cfg_scale = 2.0 + + positive = [[token for token, _ in inner_list] for inner_list in positive] + negative = [[token for token, _ in inner_list] for inner_list in negative] + positive = positive[0] + negative = negative[0] + + neg_pad = 0 + if len(negative) < len(positive): + neg_pad = (len(positive) - len(negative)) + negative = [model.special_tokens["pad"]] * neg_pad + negative + + pos_pad = 0 + if len(negative) > len(positive): + pos_pad = (len(negative) - len(positive)) + positive = [model.special_tokens["pad"]] * pos_pad + positive + + paddings = [pos_pad, neg_pad] + return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + + +class ACE15Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + out = {} + lyrics = kwargs.get("lyrics", "") + bpm = kwargs.get("bpm", 120) + duration = kwargs.get("duration", 120) + keyscale = kwargs.get("keyscale", "C major") + timesignature = kwargs.get("timesignature", 2) + language = kwargs.get("language", "en") + seed = kwargs.get("seed", 0) + + duration = math.ceil(duration) + meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n" + + meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration) + out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True) + out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True) + + out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) + out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) + out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed} + return out + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_2B_ACE15(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class ACE15TEModel(torch.nn.Module): + def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}): + super().__init__() + if dtype_llama is None: + dtype_llama = dtype + + self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options) + self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options) + self.dtypes = set([dtype, dtype_llama]) + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_base = token_weight_pairs["qwen3_06b"] + token_weight_pairs_lyrics = token_weight_pairs["lyrics"] + + self.qwen3_06b.set_clip_options({"layer": None}) + base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base) + self.qwen3_06b.set_clip_options({"layer": [0]}) + lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics) + + lm_metadata = token_weight_pairs["lm_metadata"] + audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) + + return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]} + + def set_clip_options(self, options): + self.qwen3_06b.set_clip_options(options) + self.qwen3_2b.set_clip_options(options) + + def reset_clip_options(self): + self.qwen3_06b.reset_clip_options() + self.qwen3_2b.reset_clip_options() + + def load_sd(self, sd): + if "model.layers.0.post_attention_layernorm.weight" in sd: + shape = sd["model.layers.0.post_attention_layernorm.weight"].shape + if shape[0] == 1024: + return self.qwen3_06b.load_sd(sd) + else: + return self.qwen3_2b.load_sd(sd) + + def memory_estimation_function(self, token_weight_pairs, device=None): + lm_metadata = token_weight_pairs["lm_metadata"] + constant = 0.4375 + if comfy.model_management.should_use_bf16(device): + constant *= 0.5 + + token_weight_pairs = token_weight_pairs.get("lm_prompt", []) + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + num_tokens += lm_metadata['min_tokens'] + return num_tokens * constant * 1024 * 1024 + +def te(dtype_llama=None, llama_quantization_metadata=None): + class ACE15TEModel_(ACE15TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options) + return ACE15TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 68ac1e804..d2324ffc5 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -103,6 +103,52 @@ class Qwen3_06BConfig: final_norm: bool = True lm_head: bool = False +@dataclass +class Qwen3_06B_ACE15_Config: + vocab_size: int = 151669 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + +@dataclass +class Qwen3_2B_ACE15_lm_Config: + vocab_size: int = 217204 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -729,6 +775,27 @@ class Qwen3_06B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06B_ACE15_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_2B_ACE15_lm_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + def logits(self, x): + return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None) + class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 1409233c9..376584e5c 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -28,12 +28,39 @@ class TextEncodeAceStepAudio(io.ComfyNode): conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) return io.NodeOutput(conditioning) +class TextEncodeAceStepAudio15(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeAceStepAudio1.5", + category="conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("tags", multiline=True, dynamic_prompts=True), + io.String.Input("lyrics", multiline=True, dynamic_prompts=True), + io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Int.Input("bpm", default=120, min=10, max=300), + io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), + io.Combo.Input("timesignature", options=['2', '3', '4', '6']), + io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), + io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed) + conditioning = clip.encode_from_tokens_scheduled(tokens) + return io.NodeOutput(conditioning) + class EmptyAceStepLatentAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="EmptyAceStepLatentAudio", + display_name="Empty Ace Step 1.0 Latent Audio", category="latent/audio", inputs=[ io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), @@ -51,12 +78,60 @@ class EmptyAceStepLatentAudio(io.ComfyNode): return io.NodeOutput({"samples": latent, "type": "audio"}) +class EmptyAceStep15LatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyAceStep1.5LatentAudio", + display_name="Empty Ace Step 1.5 Latent Audio", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> io.NodeOutput: + length = round((seconds * 48000 / 1920)) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) + +class ReferenceTimbreAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReferenceTimbreAudio", + category="advanced/conditioning/audio", + is_experimental=True, + description="This node sets the reference audio for timbre (for ace step 1.5)", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: + if latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) + return io.NodeOutput(conditioning) + class AceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeAceStepAudio, EmptyAceStepLatentAudio, + TextEncodeAceStepAudio15, + EmptyAceStep15LatentAudio, + ReferenceTimbreAudio, ] async def comfy_entrypoint() -> AceExtension: diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 271b75fbd..bef723dce 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -82,13 +82,14 @@ class VAEEncodeAudio(IO.ComfyNode): @classmethod def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] - if 44100 != sample_rate: - waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate) else: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return IO.NodeOutput({"samples":t}) + return IO.NodeOutput({"samples": t}) encode = execute # TODO: remove @@ -114,7 +115,8 @@ class VAEDecodeAudio(IO.ComfyNode): std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove diff --git a/nodes.py b/nodes.py index 1cb43d9e2..e11a8ed80 100644 --- a/nodes.py +++ b/nodes.py @@ -1001,7 +1001,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}),