import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
from comfy.ldm.flux.layers import timestep_embedding


def get_layer_class(operations, layer_name):
    """Return the layer class called *layer_name*.

    The class is taken from ``operations`` when that object is not None and
    has an attribute with this name; otherwise it is taken from ``torch.nn``.
    """
    if operations is not None and hasattr(operations, layer_name):
        return getattr(operations, layer_name)
    return getattr(nn, layer_name)


class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    The tables are stored as non-persistent buffers, so they are not part of
    the state dict. ``operations`` is accepted for signature parity with the
    other modules in this file and is not used.
    """

    def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings

        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim
        inv_freq = 1.0 / (self.base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        cache_dtype = torch.get_default_dtype() if dtype is None else dtype
        self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=cache_dtype)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin buffers for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        # When the cache is regrown from forward(), ``device`` is the input's
        # device, which may differ from where ``inv_freq`` currently lives.
        inv_freq = self.inv_freq if device is None else self.inv_freq.to(device)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions.

        ``x`` only supplies the target dtype and device. When ``seq_len`` is
        omitted it falls back to ``x.shape[-2]``. The cache is rebuilt when
        ``seq_len`` exceeds the cached length.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, x.device, x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
            self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device),
        )


def rotate_half(x):
    """Split the last dimension in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to ``q`` and ``k``.

    ``cos`` and ``sin`` get two leading singleton dimensions so that they
    broadcast against ``q`` and ``k``.
    """
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)


class TimestepEmbedding(nn.Module):
    """Embed a timestep and project it to six modulation vectors."""

    def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device)
        self.act1 = nn.SiLU()
        self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device)
        self.in_channels = in_channels
        self.act2 = nn.SiLU()
        self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device)
        self.scale = scale

    def forward(self, t, dtype=None):
        """Return ``(temb, timestep_proj)``.

        ``timestep_proj`` is ``time_proj(SiLU(temb))`` viewed as
        ``(-1, 6, temb.shape[-1])``.
        """
        t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale)
        temb = self.linear_1(t_freq.to(dtype=dtype))
        temb = self.act1(temb)
        temb = self.linear_2(temb)
        timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1])
        return temb, timestep_proj


class AceStepAttention(nn.Module):
    """Multi-head attention with per-head RMSNorm on queries and keys.

    Supports fewer key/value heads than query heads (keys/values are repeated
    to match), optional cross-attention, and an optional symmetric sliding
    window for self-attention.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        head_dim,
        rms_norm_eps=1e-6,
        is_cross_attention=False,
        sliding_window=None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attention = is_cross_attention
        self.sliding_window = sliding_window

        Linear = get_layer_class(operations, "Linear")
        RMSNorm = get_layer_class(operations, "RMSNorm")

        self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device)
        self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
        self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
        self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device)

        self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)
        self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        position_embeddings=None,
    ):
        """Attend from ``hidden_states`` to itself or to ``encoder_hidden_states``.

        ``hidden_states`` is 3-D (batch, sequence, features). Keys and values
        come from ``encoder_hidden_states`` only when this layer was built
        with ``is_cross_attention=True`` and that tensor is provided.
        ``position_embeddings`` is an optional ``(cos, sin)`` pair.

        NOTE: ``attention_mask`` is accepted for interface compatibility but is
        not applied; the only bias passed to the attention kernel is the
        sliding-window bias.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)
        query_states = query_states.transpose(1, 2)

        if self.is_cross_attention and encoder_hidden_states is not None:
            kv_source = encoder_hidden_states
            kv_bsz, kv_len, _ = encoder_hidden_states.size()
        else:
            kv_source = hidden_states
            kv_bsz, kv_len = bsz, q_len

        key_states = self.k_proj(kv_source)
        value_states = self.v_proj(kv_source)

        key_states = key_states.view(kv_bsz, kv_len, self.num_kv_heads, self.head_dim)
        key_states = self.k_norm(key_states)
        value_states = value_states.view(kv_bsz, kv_len, self.num_kv_heads, self.head_dim)

        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Repeat key/value heads so every query head has a matching pair.
        n_rep = self.num_heads // self.num_kv_heads
        if n_rep > 1:
            key_states = key_states.repeat_interleave(n_rep, dim=1)
            value_states = value_states.repeat_interleave(n_rep, dim=1)

        attn_bias = None
        if self.sliding_window is not None and not self.is_cross_attention:
            # Additive bias: 0 inside the window, the dtype's minimum outside.
            positions = torch.arange(q_len, device=query_states.device)
            distance = torch.abs(positions.unsqueeze(1) - positions.unsqueeze(0))
            attn_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype)
            attn_bias.masked_fill_(distance > self.sliding_window, torch.finfo(query_states.dtype).min)
            attn_bias = attn_bias.unsqueeze(0).unsqueeze(0)

        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
        attn_output = self.o_proj(attn_output)

        return attn_output


class AceStepDiTLayer(nn.Module):
    """One DiT block: modulated self-attention, cross-attention, modulated MLP."""

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        layer_type="full_attention",
        sliding_window=128,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        RMSNorm = get_layer_class(operations, "RMSNorm")
        # Only "sliding_attention" layers restrict self-attention to a window.
        self_attn_window = sliding_window if layer_type == "sliding_attention" else None

        self.self_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.self_attn = AceStepAttention(
            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
            is_cross_attention=False, sliding_window=self_attn_window,
            dtype=dtype, device=device, operations=operations
        )

        self.cross_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.cross_attn = AceStepAttention(
            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
            is_cross_attention=True, dtype=dtype, device=device, operations=operations
        )

        self.mlp_norm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)

        self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device))

    def forward(
        self,
        hidden_states,
        temb,
        encoder_hidden_states,
        position_embeddings,
        attention_mask=None,
        encoder_attention_mask=None
    ):
        """Run the block and return the updated ``hidden_states``.

        ``temb`` is added to the learned ``scale_shift_table`` and split into
        six chunks along dim 1: shift/scale/gate for self-attention and
        shift/scale/gate for the MLP. Cross-attention is neither modulated
        nor gated.
        """
        table = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device)
        modulation = table + temb
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1)

        norm_hidden = self.self_attn_norm(hidden_states)
        norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa
        attn_out = self.self_attn(
            norm_hidden,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask
        )
        hidden_states = hidden_states + attn_out * gate_msa

        norm_hidden = self.cross_attn_norm(hidden_states)
        attn_out = self.cross_attn(
            norm_hidden,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask
        )
        hidden_states = hidden_states + attn_out

        norm_hidden = self.mlp_norm(hidden_states)
        norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa
        mlp_out = self.mlp(norm_hidden)
        hidden_states = hidden_states + mlp_out * c_gate_msa

        return hidden_states


class AceStepEncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention and MLP, each with a residual."""

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        RMSNorm = get_layer_class(operations, "RMSNorm")
        self.self_attn = AceStepAttention(
            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
            is_cross_attention=False, dtype=dtype, device=device, operations=operations
        )
        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)

    def forward(self, hidden_states, position_embeddings, attention_mask=None):
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out
        return hidden_states


class AceStepLyricEncoder(nn.Module):
    """Project lyric embeddings to ``hidden_size`` and run a stack of encoder layers."""

    def __init__(
        self,
        text_hidden_dim,
        hidden_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        RMSNorm = get_layer_class(operations, "RMSNorm")
        self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = RotaryEmbedding(
            head_dim,
            base=1000000.0,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])

    def forward(self, inputs_embeds, attention_mask=None):
        hidden_states = self.embed_tokens(inputs_embeds)
        seq_len = hidden_states.shape[1]
        position_embeddings = self.rotary_emb(hidden_states, seq_len=seq_len)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states


class AceStepTimbreEncoder(nn.Module):
    """Encode reference-audio features and group the results by batch item.

    ``special_token`` is registered as a parameter; it is not read by the
    code in this class.
    """

    def __init__(
        self,
        timbre_hidden_dim,
        hidden_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        RMSNorm = get_layer_class(operations, "RMSNorm")
        self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = RotaryEmbedding(
            head_dim,
            base=1000000.0,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])
        self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))

    def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
        """Scatter packed ``(N, d)`` embeddings into a padded per-group tensor.

        ``refer_audio_order_mask[i]`` is the group id of row ``i``. Returns
        ``(embeddings, mask)`` of shapes ``(N, max_count, d)`` and
        ``(N, max_count)``, where ``max_count`` is the largest group size and
        ``mask`` is 1 for filled slots and 0 for padding. The leading
        dimension is ``N`` (the number of packed rows), used as the upper
        bound on the number of groups.
        """
        N, d = timbre_embs_packed.shape
        device = timbre_embs_packed.device
        B = N
        counts = torch.bincount(refer_audio_order_mask, minlength=B)
        max_count = counts.max().item()

        # Stable order: by group id first, original row index second.
        sorted_indices = torch.argsort(
            refer_audio_order_mask * N + torch.arange(N, device=device),
            stable=True
        )
        sorted_batch_ids = refer_audio_order_mask[sorted_indices]

        # Rank of every row within its own group.
        positions = torch.arange(N, device=device)
        batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
        positions_in_sorted = positions - batch_starts[sorted_batch_ids]

        inverse_indices = torch.empty_like(sorted_indices)
        inverse_indices[sorted_indices] = torch.arange(N, device=device)
        positions_in_batch = positions_in_sorted[inverse_indices]

        # Scatter through a one-hot matrix over the flattened (group, slot) grid.
        indices_2d = refer_audio_order_mask * max_count + positions_in_batch
        one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype)

        timbre_embs_flat = one_hot.t() @ timbre_embs_packed
        timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d)

        mask_flat = (one_hot.sum(dim=0) > 0).long()
        new_mask = mask_flat.view(B, max_count)
        return timbre_embs_unpack, new_mask

    def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None):
        """Encode the packed reference audio and return ``(embeddings, mask)``.

        Only the hidden state at sequence position 0 of each packed item is
        kept and passed to ``unpack_timbre_embeddings``.
        """
        hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed)
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(0)

        seq_len = hidden_states.shape[1]
        cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=(cos, sin),
                attention_mask=attention_mask
            )
        hidden_states = self.norm(hidden_states)

        flat_states = hidden_states[:, 0, :]
        unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask)
        return unpacked_embs, unpacked_mask


def pack_sequences(hidden1, hidden2, mask1, mask2):
    """Concatenate two sequences along dim 1 and move valid tokens to the front.

    When both masks are given, tokens are stably reordered so that masked-in
    positions come first, and the returned boolean mask marks the first
    ``mask.sum()`` positions of every row. When either mask is None the
    plain concatenation is returned together with ``None``.
    """
    hidden_cat = torch.cat([hidden1, hidden2], dim=1)
    B, L, D = hidden_cat.shape

    if mask1 is None or mask2 is None:
        return hidden_cat, None

    mask_cat = torch.cat([mask1, mask2], dim=1)
    sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
    gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D)
    hidden_sorted = torch.gather(hidden_cat, 1, gather_idx)
    lengths = mask_cat.sum(dim=1)
    new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
    return hidden_sorted, new_mask


class AceStepConditionEncoder(nn.Module):
    """Encode text, lyrics and reference audio into one conditioning sequence."""

    def __init__(
        self,
        text_hidden_dim,
        timbre_hidden_dim,
        hidden_size,
        num_lyric_layers,
        num_timbre_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device)

        self.lyric_encoder = AceStepLyricEncoder(
            text_hidden_dim=text_hidden_dim,
            hidden_size=hidden_size,
            num_layers=num_lyric_layers,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            rms_norm_eps=rms_norm_eps,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.timbre_encoder = AceStepTimbreEncoder(
            timbre_hidden_dim=timbre_hidden_dim,
            hidden_size=hidden_size,
            num_layers=num_timbre_layers,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            rms_norm_eps=rms_norm_eps,
            dtype=dtype,
            device=device,
            operations=operations
        )

    def forward(
        self,
        text_hidden_states=None,
        text_attention_mask=None,
        lyric_hidden_states=None,
        lyric_attention_mask=None,
        refer_audio_acoustic_hidden_states_packed=None,
        refer_audio_order_mask=None
    ):
        """Return ``(embeddings, mask)`` packed in the order lyrics, timbre, text."""
        text_emb = self.text_projector(text_hidden_states)

        lyric_emb = self.lyric_encoder(
            inputs_embeds=lyric_hidden_states,
            attention_mask=lyric_attention_mask
        )

        timbre_emb, timbre_mask = self.timbre_encoder(
            refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask
        )

        merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask)
        final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask)

        return final_emb, final_mask


# --------------------------------------------------------------------------------
# Main Diffusion Model (DiT)
# --------------------------------------------------------------------------------

class AceStepDiTModel(nn.Module):
    """Diffusion transformer operating on patchified latent sequences."""

    def __init__(
        self,
        in_channels,
        hidden_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        patch_size,
        audio_acoustic_hidden_dim,
        layer_types=None,
        sliding_window=128,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.patch_size = patch_size
        self.rotary_emb = RotaryEmbedding(
            head_dim,
            base=1000000.0,
            dtype=dtype,
            device=device,
            operations=operations
        )

        Conv1d = get_layer_class(operations, "Conv1d")
        ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d")
        Linear = get_layer_class(operations, "Linear")
        RMSNorm = get_layer_class(operations, "RMSNorm")

        # The leading nn.Identity keeps the convolution at index 1 of the
        # Sequential, which fixes its parameter names in the state dict.
        self.proj_in = nn.Sequential(
            nn.Identity(),
            Conv1d(
                in_channels, hidden_size, kernel_size=patch_size, stride=patch_size,
                dtype=dtype, device=device))

        self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device)

        if layer_types is None:
            layer_types = ["full_attention"] * num_layers

        # A shorter pattern is repeated cyclically to cover every layer.
        if len(layer_types) < num_layers:
            layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers))

        self.layers = nn.ModuleList([
            AceStepDiTLayer(
                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
                layer_type=layer_types[i],
                sliding_window=sliding_window,
                dtype=dtype, device=device, operations=operations
            ) for i in range(num_layers)
        ])

        self.norm_out = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.proj_out = nn.Sequential(
            nn.Identity(),
            ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device)
        )

        self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device))

    def forward(
        self,
        hidden_states,
        timestep,
        timestep_r,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        context_latents
    ):
        """Run the DiT and return a tensor with the input's sequence length.

        ``context_latents`` and ``hidden_states`` are concatenated on the last
        dimension; the sequence (dim 1) is zero-padded to a multiple of
        ``patch_size`` and cropped back at the end.

        NOTE: ``attention_mask`` and ``encoder_attention_mask`` are accepted
        but every layer is called with both masks set to ``None``.
        """
        temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype)
        temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype)
        temb = temb_t + temb_r
        timestep_proj = proj_t + proj_r

        x = torch.cat([context_latents, hidden_states], dim=-1)
        original_seq_len = x.shape[1]

        remainder = original_seq_len % self.patch_size
        if remainder != 0:
            pad_length = self.patch_size - remainder
            x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0)

        # The convolutions expect channels on dim 1.
        x = x.transpose(1, 2)
        x = self.proj_in(x)
        x = x.transpose(1, 2)

        encoder_hidden_states = self.condition_embedder(encoder_hidden_states)

        seq_len = x.shape[1]
        cos, sin = self.rotary_emb(x, seq_len=seq_len)

        for layer in self.layers:
            x = layer(
                hidden_states=x,
                temb=timestep_proj,
                encoder_hidden_states=encoder_hidden_states,
                position_embeddings=(cos, sin),
                attention_mask=None,
                encoder_attention_mask=None
            )

        table = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device)
        shift, scale = (table + temb.unsqueeze(1)).chunk(2, dim=1)
        x = self.norm_out(x) * (1 + scale) + shift

        x = x.transpose(1, 2)
        x = self.proj_out(x)
        x = x.transpose(1, 2)

        x = x[:, :original_seq_len, :]
        return x


class AttentionPooler(nn.Module):
    """Pool every window of ``P`` frames into one vector via a prepended token."""

    def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        RMSNorm = get_layer_class(operations, "RMSNorm")
        self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
        self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])

    def forward(self, x):
        """Map ``(B, T, P, D)`` to ``(B, T, D)``.

        The learned special token is prepended to each window; its output
        state (position 0) is the pooled value.
        """
        B, T, P, D = x.shape
        x = self.embed_tokens(x)
        # Cast the parameter to the activations' dtype/device before
        # concatenating, as AudioTokenDetokenizer does for its special tokens.
        special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype)
        special = special.expand(B, T, 1, -1)
        x = torch.cat([special, x], dim=2)
        x = x.view(B * T, P + 1, D)

        cos, sin = self.rotary_emb(x, seq_len=P + 1)
        for layer in self.layers:
            x = layer(x, (cos, sin))

        x = self.norm(x)
        return x[:, 0, :].view(B, T, D)


class FSQ(nn.Module):
    """Finite scalar quantizer with an implicit codebook.

    ``levels`` gives the number of quantization levels per codebook
    dimension. When ``dim`` differs from ``len(levels)`` linear projections
    map into and out of the codebook space.
    """

    def __init__(
        self,
        levels,
        dim=None,
        device=None,
        dtype=None,
        operations=None
    ):
        super().__init__()

        levels = list(levels)  # allow tuples; the basis below needs list concatenation
        _levels = torch.tensor(levels, dtype=torch.int32, device=device)
        self.register_buffer('_levels', _levels, persistent=False)

        # Mixed-radix place values used to convert codes to and from flat indices.
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0)
        self.register_buffer('_basis', _basis, persistent=False)

        self.codebook_dim = len(levels)
        self.dim = dim if dim is not None else self.codebook_dim

        requires_projection = self.dim != self.codebook_dim
        if requires_projection:
            Linear = get_layer_class(operations, "Linear")
            self.project_in = Linear(self.dim, self.codebook_dim, device=device, dtype=dtype)
            self.project_out = Linear(self.codebook_dim, self.dim, device=device, dtype=dtype)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.codebook_size = self._levels.prod().item()

        indices = torch.arange(self.codebook_size, device=device)
        implicit_codebook = self._indices_to_codes(indices)

        if dtype is not None:
            implicit_codebook = implicit_codebook.to(dtype)

        self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)

    def bound(self, z):
        """Squash ``z`` with tanh and round to the level grid, rescaled to [-1, 1].

        The rounding uses a straight-through estimator: the rounded value is
        used in the forward pass while gradients follow the unrounded one.
        """
        levels_minus_1 = (self._levels - 1).to(z.dtype)
        scale = 2. / levels_minus_1
        bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5

        zhat = bracket.floor()
        bracket_ste = bracket + (zhat - bracket).detach()

        return scale * bracket_ste - 1.

    def _indices_to_codes(self, indices):
        """Convert flat codebook indices to code vectors in [-1, 1]."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.

    def codes_to_indices(self, zhat):
        """Convert code vectors in [-1, 1] to flat int32 codebook indices."""
        zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
        return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)

    def forward(self, z):
        """Return ``(quantized, indices)``; ``quantized`` has the dtype of ``z``."""
        orig_dtype = z.dtype
        z = self.project_in(z)

        codes = self.bound(z)
        indices = self.codes_to_indices(codes)

        out = self.project_out(codes)
        return out.to(orig_dtype), indices


class ResidualFSQ(nn.Module):
    """Stack of FSQ quantizers, each quantizing the residual of the previous ones.

    Quantizer ``i`` works at scale ``levels ** -i``. Extra keyword arguments
    are accepted and ignored.
    """

    def __init__(
        self,
        levels,
        num_quantizers,
        dim=None,
        bound_hard_clamp=True,
        device=None,
        dtype=None,
        operations=None,
        **kwargs
    ):
        super().__init__()

        levels = list(levels)
        codebook_dim = len(levels)
        dim = dim if dim is not None else codebook_dim

        requires_projection = codebook_dim != dim
        if requires_projection:
            Linear = get_layer_class(operations, "Linear")
            self.project_in = Linear(dim, codebook_dim, device=device, dtype=dtype)
            self.project_out = Linear(codebook_dim, dim, device=device, dtype=dtype)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.layers = nn.ModuleList()
        levels_tensor = torch.tensor(levels, device=device)
        scales = []

        for ind in range(num_quantizers):
            scales.append(levels_tensor.float() ** -ind)
            self.layers.append(FSQ(
                levels=levels,
                dim=codebook_dim,
                device=device,
                dtype=dtype,
                operations=operations
            ))

        scales_tensor = torch.stack(scales)
        if dtype is not None:
            scales_tensor = scales_tensor.to(dtype)
        self.register_buffer('scales', scales_tensor, persistent=False)

        # The buffer only exists when clamping is enabled; forward() checks
        # for its presence.
        if bound_hard_clamp:
            val = 1 + (1 / (levels_tensor.float() - 1))
            if dtype is not None:
                val = val.to(dtype)
            self.register_buffer('soft_clamp_input_value', val, persistent=False)

    def get_output_from_indices(self, indices, dtype=torch.float32):
        """Decode per-quantizer indices to the projected output.

        A 2-D ``indices`` tensor gets a trailing quantizer dimension of size
        one. The last dimension indexes the quantizers.
        """
        if indices.dim() == 2:
            indices = indices.unsqueeze(-1)

        all_codes = []
        for i, layer in enumerate(self.layers):
            idx = indices[..., i].long()
            codebook = comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype)
            scale = comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype)
            all_codes.append(F.embedding(idx, codebook) * scale)

        codes_summed = torch.stack(all_codes).sum(dim=0)
        return self.project_out(codes_summed)

    def forward(self, x):
        """Return ``(quantized_out, all_indices)``; indices are stacked on the last dim."""
        x = self.project_in(x)

        if hasattr(self, 'soft_clamp_input_value'):
            sc_val = self.soft_clamp_input_value.to(x.dtype)
            x = (x / sc_val).tanh() * sc_val

        quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
        residual = x
        all_indices = []

        for layer, scale in zip(self.layers, self.scales):
            scale = scale.to(residual.dtype)

            quantized, indices = layer(residual / scale)
            quantized = quantized * scale

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized
            all_indices.append(indices)

        quantized_out = self.project_out(quantized_out)
        all_indices = torch.stack(all_indices, dim=-1)

        return quantized_out, all_indices


class AceStepAudioTokenizer(nn.Module):
    """Project acoustic features, pool them per window and quantize the result."""

    def __init__(
        self,
        audio_acoustic_hidden_dim,
        hidden_size,
        pool_window_size,
        fsq_dim,
        fsq_levels,
        fsq_input_num_quantizers,
        num_layers,
        head_dim,
        rms_norm_eps,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device)
        self.attention_pooler = AttentionPooler(
            hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations
        )
        self.pool_window_size = pool_window_size
        self.fsq_dim = fsq_dim
        self.quantizer = ResidualFSQ(
            dim=fsq_dim,
            levels=fsq_levels,
            num_quantizers=fsq_input_num_quantizers,
            bound_hard_clamp=True,
            dtype=dtype, device=device, operations=operations
        )

    def forward(self, hidden_states):
        """Quantize windowed features; returns ``(quantized, indices)``."""
        hidden_states = self.audio_acoustic_proj(hidden_states)
        hidden_states = self.attention_pooler(hidden_states)
        quantized, indices = self.quantizer(hidden_states)
        return quantized, indices

    def tokenize(self, x):
        """Tokenize a ``(B, T, D)`` sequence.

        The sequence is zero-padded at the end to a multiple of
        ``pool_window_size`` and reshaped to ``(B, T // P, P, D)`` before
        calling ``forward``.
        """
        B, T, D = x.shape
        P = self.pool_window_size

        if T % P != 0:
            pad = P - (T % P)
            x = F.pad(x, (0, 0, 0, pad))
            T = x.shape[1]

        T_patch = T // P
        x = x.view(B, T_patch, P, D)

        quantized, indices = self.forward(x)
        return quantized, indices


class AudioTokenDetokenizer(nn.Module):
    """Expand each token into ``pool_window_size`` frames of acoustic features."""

    def __init__(
        self,
        hidden_size,
        pool_window_size,
        audio_acoustic_hidden_dim,
        num_layers,
        head_dim,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        RMSNorm = get_layer_class(operations, "RMSNorm")
        self.pool_window_size = pool_window_size
        self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
        self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device))
        self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device)
        self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device)

    def forward(self, x):
        """Map ``(B, T, D)`` tokens to ``(B, T * pool_window_size, out_dim)`` frames."""
        B, T, D = x.shape
        x = self.embed_tokens(x)
        # Repeat every token once per frame slot and add the per-slot token.
        x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
        x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype)
        x = x.view(B * T, self.pool_window_size, D)

        cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size)
        for layer in self.layers:
            x = layer(x, (cos, sin))

        x = self.norm(x)
        x = self.proj_out(x)
        return x.view(B, T * self.pool_window_size, -1)


class AceStepConditionGenerationModel(nn.Module):
    """Top-level model: condition encoder, audio tokenizer/detokenizer and DiT decoder."""

    def __init__(
        self,
        in_channels=192,
        hidden_size=2048,
        text_hidden_dim=1024,
        timbre_hidden_dim=64,
        audio_acoustic_hidden_dim=64,
        num_dit_layers=24,
        num_lyric_layers=8,
        num_timbre_layers=4,
        num_tokenizer_layers=2,
        num_heads=16,
        num_kv_heads=8,
        head_dim=128,
        intermediate_size=6144,
        patch_size=2,
        pool_window_size=5,
        rms_norm_eps=1e-06,
        timestep_mu=-0.4,
        timestep_sigma=1.0,
        data_proportion=0.5,
        sliding_window=128,
        layer_types=None,
        fsq_dim=2048,
        fsq_levels=None,
        fsq_input_num_quantizers=1,
        audio_model=None,
        dtype=None,
        device=None,
        operations=None
    ):
        """Build all sub-models.

        ``fsq_levels`` defaults to ``[8, 8, 8, 5, 5, 5]`` when None.
        ``layer_types`` defaults to alternating ``"sliding_attention"`` (even
        layer indices) and ``"full_attention"`` (odd layer indices).
        ``audio_model`` is accepted and not used here.
        """
        super().__init__()
        self.dtype = dtype
        self.timestep_mu = timestep_mu
        self.timestep_sigma = timestep_sigma
        self.data_proportion = data_proportion
        self.pool_window_size = pool_window_size

        if fsq_levels is None:
            fsq_levels = [8, 8, 8, 5, 5, 5]

        if layer_types is None:
            layer_types = [
                "sliding_attention" if i % 2 == 0 else "full_attention"
                for i in range(num_dit_layers)
            ]

        self.decoder = AceStepDiTModel(
            in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim,
            intermediate_size, patch_size, audio_acoustic_hidden_dim,
            layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps,
            dtype=dtype, device=device, operations=operations
        )
        self.encoder = AceStepConditionEncoder(
            text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers,
            num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
            dtype=dtype, device=device, operations=operations
        )
        self.tokenizer = AceStepAudioTokenizer(
            audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
            dtype=dtype, device=device, operations=operations
        )
        self.detokenizer = AudioTokenDetokenizer(
            hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
            dtype=dtype, device=device, operations=operations
        )
        self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))

    def prepare_condition(
        self,
        text_hidden_states, text_attention_mask,
        lyric_hidden_states, lyric_attention_mask,
        refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
        src_latents, chunk_masks, is_covers,
        precomputed_lm_hints_25Hz=None,
        audio_codes=None
    ):
        """Build the decoder's conditioning inputs.

        Returns ``(encoder_hidden, encoder_mask, context_latents)`` where
        ``context_latents`` is the source latents concatenated with
        ``chunk_masks`` on the last dimension. The source latents are the LM
        hints when ``is_covers`` is None; otherwise hints are selected per
        batch item where ``is_covers > 0``.

        Raises:
            ValueError: if neither ``precomputed_lm_hints_25Hz`` nor
                ``audio_codes`` is provided.
        """
        encoder_hidden, encoder_mask = self.encoder(
            text_hidden_states, text_attention_mask,
            lyric_hidden_states, lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask
        )

        if precomputed_lm_hints_25Hz is not None:
            lm_hints = precomputed_lm_hints_25Hz
        elif audio_codes is not None:
            # Pad the code sequence so that, after the detokenizer's expansion,
            # it covers the whole source length.
            # NOTE(review): 5 presumably mirrors pool_window_size and 35847 is
            # presumably a padding/silence code - confirm against the checkpoint.
            if audio_codes.shape[1] * 5 < src_latents.shape[1]:
                audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
            lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
            lm_hints = self.detokenizer(lm_hints_5Hz)
        else:
            # Deriving hints by tokenizing src_latents is not implemented.
            raise ValueError("prepare_condition requires either precomputed_lm_hints_25Hz or audio_codes")

        lm_hints = lm_hints[:, :src_latents.shape[1], :]
        if is_covers is None:
            src_latents = lm_hints
        else:
            src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)

        context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)

        return encoder_hidden, encoder_mask, context_latents

    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
        """Denoise ``x`` conditioned on text, lyrics, reference audio and audio codes.

        ``x`` and ``refer_audio`` have their last two dimensions swapped on
        entry, and the output is swapped back before returning. Extra keyword
        arguments are ignored.
        """
        # Optional inputs that this entry point does not expose; they are
        # passed through as None.
        text_attention_mask = None
        lyric_attention_mask = None
        attention_mask = None
        is_covers = None
        precomputed_lm_hints_25Hz = None

        lyric_hidden_states = lyric_embed
        text_hidden_states = context
        refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2)

        x = x.movedim(-1, -2)

        # Every reference-audio item is assigned to group 0.
        refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)

        # With is_covers=None the source latents only provide the target length.
        src_latents = x
        chunk_masks = torch.ones_like(x)

        enc_hidden, enc_mask, context_latents = self.prepare_condition(
            text_hidden_states, text_attention_mask,
            lyric_hidden_states, lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
            src_latents, chunk_masks, is_covers,
            precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz,
            audio_codes=audio_codes
        )

        out = self.decoder(
            hidden_states=x,
            timestep=timestep,
            timestep_r=timestep,
            attention_mask=attention_mask,
            encoder_hidden_states=enc_hidden,
            encoder_attention_mask=enc_mask,
            context_latents=context_latents
        )

        return out.movedim(-1, -2)