import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
from comfy.ldm.flux.layers import timestep_embedding


def get_layer_class(operations, layer_name):
    if operations is not None and hasattr(operations, layer_name):
        return getattr(operations, layer_name)
    return getattr(nn, layer_name)


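# RotaryEmbedding precomputes cos/sin tables for rotary position embeddings
# (RoPE) up to max_position_embeddings and regrows the cache on demand when a
# longer sequence is requested; forward() returns the tables sliced to seq_len
# and cast to the caller's dtype/device.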
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.get_default_dtype() if dtype is None else dtype)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, x.device, x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
            self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device),
        )


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


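# Gated feed-forward block in the SwiGLU style: SiLU(gate_proj(x)) * up_proj(x)
# followed by down_proj, all without biases.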
class MLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device)
        self.act1 = nn.SiLU()
        self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device)
        self.in_channels = in_channels
        self.act2 = nn.SiLU()
        self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device)
        self.scale = scale

    def forward(self, t, dtype=None):
        t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale)
        temb = self.linear_1(t_freq.to(dtype=dtype))
        temb = self.act1(temb)
        temb = self.linear_2(temb)
        timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1])
        return temb, timestep_proj


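# Grouped-query attention with per-head RMSNorm on queries and keys. KV heads
# are repeated to match the query head count, RoPE is applied when position
# embeddings are passed, and (for self-attention layers configured with a
# sliding window) a banded additive bias masks out tokens further than
# `sliding_window` positions away before calling optimized_attention.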
class AceStepAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        head_dim,
        rms_norm_eps=1e-6,
        is_cross_attention=False,
        sliding_window=None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attention = is_cross_attention
        self.sliding_window = sliding_window

        Linear = get_layer_class(operations, "Linear")

        self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device)
        self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
        self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
        self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device)

        self.q_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        position_embeddings=None,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)
        query_states = query_states.transpose(1, 2)

        if self.is_cross_attention and encoder_hidden_states is not None:
            bsz_enc, kv_len, _ = encoder_hidden_states.size()
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)

            key_states = key_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim)
            key_states = self.k_norm(key_states)
            value_states = value_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim)

            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
        else:
            kv_len = q_len
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

            key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
            key_states = self.k_norm(key_states)
            value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)

            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        n_rep = self.num_heads // self.num_kv_heads
        if n_rep > 1:
            key_states = key_states.repeat_interleave(n_rep, dim=1)
            value_states = value_states.repeat_interleave(n_rep, dim=1)

        attn_bias = None
        if self.sliding_window is not None and not self.is_cross_attention:
            indices = torch.arange(q_len, device=query_states.device)
            diff = indices.unsqueeze(1) - indices.unsqueeze(0)
            in_window = torch.abs(diff) <= self.sliding_window

            window_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype)
            min_value = torch.finfo(query_states.dtype).min
            window_bias.masked_fill_(~in_window, min_value)

            window_bias = window_bias.unsqueeze(0).unsqueeze(0)

            if attn_bias is not None:
                if attn_bias.dtype == torch.bool:
                    base_bias = torch.zeros_like(window_bias)
                    base_bias.masked_fill_(~attn_bias, min_value)
                    attn_bias = base_bias + window_bias
                else:
                    attn_bias = attn_bias + window_bias
            else:
                attn_bias = window_bias

        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
        attn_output = self.o_proj(attn_output)

        return attn_output


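# DiT block with AdaLN-style modulation: the 6-way timestep projection supplies
# shift/scale/gate for the self-attention path and for the MLP path, while the
# cross-attention over the condition sequence is applied unmodulated.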
class AceStepDiTLayer(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        layer_type="full_attention",
        sliding_window=128,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        self_attn_window = sliding_window if layer_type == "sliding_attention" else None

        self.self_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.self_attn = AceStepAttention(
            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
            is_cross_attention=False, sliding_window=self_attn_window,
            dtype=dtype, device=device, operations=operations
        )

        self.cross_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.cross_attn = AceStepAttention(
            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
            is_cross_attention=True, dtype=dtype, device=device, operations=operations
        )

        self.mlp_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)

        self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device))

    def forward(
        self,
        hidden_states,
        temb,
        encoder_hidden_states,
        position_embeddings,
        attention_mask=None,
        encoder_attention_mask=None
    ):
        modulation = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1)

        norm_hidden = self.self_attn_norm(hidden_states)
        norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa

        attn_out = self.self_attn(
            norm_hidden,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask
        )
        hidden_states = hidden_states + attn_out * gate_msa

        norm_hidden = self.cross_attn_norm(hidden_states)
        attn_out = self.cross_attn(
            norm_hidden,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask
        )
        hidden_states = hidden_states + attn_out

        norm_hidden = self.mlp_norm(hidden_states)
        norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa

        mlp_out = self.mlp(norm_hidden)
        hidden_states = hidden_states + mlp_out * c_gate_msa

        return hidden_states


class AceStepEncoderLayer(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.self_attn = AceStepAttention(
            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
            is_cross_attention=False, dtype=dtype, device=device, operations=operations
        )
        self.input_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.post_attention_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)

    def forward(self, hidden_states, position_embeddings, attention_mask=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class AceStepLyricEncoder(nn.Module):
    def __init__(
        self,
        text_hidden_dim,
        hidden_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = RotaryEmbedding(
            head_dim,
            base=1000000.0,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])

    def forward(self, inputs_embeds, attention_mask=None):
        hidden_states = self.embed_tokens(inputs_embeds)
        seq_len = hidden_states.shape[1]
        cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len)
        position_embeddings = (cos, sin)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states


class AceStepTimbreEncoder(nn.Module):
    def __init__(
        self,
        timbre_hidden_dim,
        hidden_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = RotaryEmbedding(
            head_dim,
            base=1000000.0,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])
        self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))

    def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
        N, d = timbre_embs_packed.shape
        device = timbre_embs_packed.device
        B = N
        counts = torch.bincount(refer_audio_order_mask, minlength=B)
        max_count = counts.max().item()

        sorted_indices = torch.argsort(
            refer_audio_order_mask * N + torch.arange(N, device=device),
            stable=True
        )
        sorted_batch_ids = refer_audio_order_mask[sorted_indices]

        positions = torch.arange(N, device=device)
        batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
        positions_in_sorted = positions - batch_starts[sorted_batch_ids]

        inverse_indices = torch.empty_like(sorted_indices)
        inverse_indices[sorted_indices] = torch.arange(N, device=device)
        positions_in_batch = positions_in_sorted[inverse_indices]

        indices_2d = refer_audio_order_mask * max_count + positions_in_batch
        one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype)

        timbre_embs_flat = one_hot.t() @ timbre_embs_packed
        timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d)

        mask_flat = (one_hot.sum(dim=0) > 0).long()
        new_mask = mask_flat.view(B, max_count)
        return timbre_embs_unpack, new_mask

    def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None):
        hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed)
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(0)

        seq_len = hidden_states.shape[1]
        cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=(cos, sin),
                attention_mask=attention_mask
            )
        hidden_states = self.norm(hidden_states)

        flat_states = hidden_states[:, 0, :]
        unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask)
        return unpacked_embs, unpacked_mask


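# Concatenates two padded sequences along the time axis and stably sorts valid
# (mask == 1) positions to the front, returning the packed hidden states and a
# fresh left-aligned mask. If either mask is missing, the inputs are simply
# concatenated and no mask is returned.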
def pack_sequences(hidden1, hidden2, mask1, mask2):
    hidden_cat = torch.cat([hidden1, hidden2], dim=1)
    B, L, D = hidden_cat.shape

    if mask1 is not None and mask2 is not None:
        mask_cat = torch.cat([mask1, mask2], dim=1)
        sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
        gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D)
        hidden_sorted = torch.gather(hidden_cat, 1, gather_idx)
        lengths = mask_cat.sum(dim=1)
        new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
    else:
        new_mask = None
        hidden_sorted = hidden_cat

    return hidden_sorted, new_mask


class AceStepConditionEncoder(nn.Module):
    def __init__(
        self,
        text_hidden_dim,
        timbre_hidden_dim,
        hidden_size,
        num_lyric_layers,
        num_timbre_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device)

        self.lyric_encoder = AceStepLyricEncoder(
            text_hidden_dim=text_hidden_dim,
            hidden_size=hidden_size,
            num_layers=num_lyric_layers,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            rms_norm_eps=rms_norm_eps,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.timbre_encoder = AceStepTimbreEncoder(
            timbre_hidden_dim=timbre_hidden_dim,
            hidden_size=hidden_size,
            num_layers=num_timbre_layers,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            rms_norm_eps=rms_norm_eps,
            dtype=dtype,
            device=device,
            operations=operations
        )

    def forward(
        self,
        text_hidden_states=None,
        text_attention_mask=None,
        lyric_hidden_states=None,
        lyric_attention_mask=None,
        refer_audio_acoustic_hidden_states_packed=None,
        refer_audio_order_mask=None
    ):
        text_emb = self.text_projector(text_hidden_states)

        lyric_emb = self.lyric_encoder(
            inputs_embeds=lyric_hidden_states,
            attention_mask=lyric_attention_mask
        )

        timbre_emb, timbre_mask = self.timbre_encoder(
            refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask
        )

        merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask)
        final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask)

        return final_emb, final_mask


# --------------------------------------------------------------------------------
# Main Diffusion Model (DiT)
# --------------------------------------------------------------------------------

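# The DiT patchifies the concatenated [context_latents, noisy latents] stream
# with a strided Conv1d, runs the AceStepDiTLayer stack under two summed
# timestep embeddings (t and t - t_r), applies a final 2-way scale/shift, and
# unpatchifies back to audio_acoustic_hidden_dim with a ConvTranspose1d.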
class AceStepDiTModel(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        patch_size,
        audio_acoustic_hidden_dim,
        layer_types=None,
        sliding_window=128,
        rms_norm_eps=1e-6,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.patch_size = patch_size
        self.rotary_emb = RotaryEmbedding(
            head_dim,
            base=1000000.0,
            dtype=dtype,
            device=device,
            operations=operations
        )

        Conv1d = get_layer_class(operations, "Conv1d")
        ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d")
        Linear = get_layer_class(operations, "Linear")

        self.proj_in = nn.Sequential(
            nn.Identity(),
            Conv1d(
                in_channels, hidden_size, kernel_size=patch_size, stride=patch_size,
                dtype=dtype, device=device))

        self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device)

        if layer_types is None:
            layer_types = ["full_attention"] * num_layers

        if len(layer_types) < num_layers:
            layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers))

        self.layers = nn.ModuleList([
            AceStepDiTLayer(
                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
                layer_type=layer_types[i],
                sliding_window=sliding_window,
                dtype=dtype, device=device, operations=operations
            ) for i in range(num_layers)
        ])

        self.norm_out = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.proj_out = nn.Sequential(
            nn.Identity(),
            ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device)
        )

        self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device))

    def forward(
        self,
        hidden_states,
        timestep,
        timestep_r,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        context_latents
    ):
        temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype)
        temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype)
        temb = temb_t + temb_r
        timestep_proj = proj_t + proj_r

        x = torch.cat([context_latents, hidden_states], dim=-1)
        original_seq_len = x.shape[1]

        pad_length = 0
        if x.shape[1] % self.patch_size != 0:
            pad_length = self.patch_size - (x.shape[1] % self.patch_size)
            x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0)

        x = x.transpose(1, 2)
        x = self.proj_in(x)
        x = x.transpose(1, 2)

        encoder_hidden_states = self.condition_embedder(encoder_hidden_states)

        seq_len = x.shape[1]
        cos, sin = self.rotary_emb(x, seq_len=seq_len)

        for layer in self.layers:
            x = layer(
                hidden_states=x,
                temb=timestep_proj,
                encoder_hidden_states=encoder_hidden_states,
                position_embeddings=(cos, sin),
                attention_mask=None,
                encoder_attention_mask=None
            )

        shift, scale = (comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
        x = self.norm_out(x) * (1 + scale) + shift

        x = x.transpose(1, 2)
        x = self.proj_out(x)
        x = x.transpose(1, 2)

        x = x[:, :original_seq_len, :]
        return x


class AttentionPooler(nn.Module):
    def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
        self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
        self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])

    def forward(self, x):
        B, T, P, D = x.shape
        x = self.embed_tokens(x)
        special = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special, x], dim=2)
        x = x.view(B * T, P + 1, D)

        cos, sin = self.rotary_emb(x, seq_len=P + 1)
        for layer in self.layers:
            x = layer(x, (cos, sin))

        x = self.norm(x)
        return x[:, 0, :].view(B, T, D)


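# Finite Scalar Quantization (FSQ): each latent dimension is squashed with tanh,
# scaled to its number of levels, and rounded; the rounding uses a
# straight-through estimator so gradients flow through `bound`. Indices are the
# mixed-radix code over the per-dimension levels via `_basis`.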
class FSQ(nn.Module):
    def __init__(
        self,
        levels,
        dim=None,
        device=None,
        dtype=None,
        operations=None
    ):
        super().__init__()

        _levels = torch.tensor(levels, dtype=torch.int32, device=device)
        self.register_buffer('_levels', _levels, persistent=False)

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0)
        self.register_buffer('_basis', _basis, persistent=False)

        self.codebook_dim = len(levels)
        self.dim = dim if dim is not None else self.codebook_dim

        requires_projection = self.dim != self.codebook_dim
        if requires_projection:
            self.project_in = operations.Linear(self.dim, self.codebook_dim, device=device, dtype=dtype)
            self.project_out = operations.Linear(self.codebook_dim, self.dim, device=device, dtype=dtype)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.codebook_size = self._levels.prod().item()

        indices = torch.arange(self.codebook_size, device=device)
        implicit_codebook = self._indices_to_codes(indices)

        if dtype is not None:
            implicit_codebook = implicit_codebook.to(dtype)

        self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)

    def bound(self, z):
        levels_minus_1 = (self._levels - 1).to(z.dtype)
        scale = 2. / levels_minus_1
        bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5

        zhat = bracket.floor()
        bracket_ste = bracket + (zhat - bracket).detach()

        return scale * bracket_ste - 1.

    def _indices_to_codes(self, indices):
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.

    def codes_to_indices(self, zhat):
        zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
        return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)

    def forward(self, z):
        orig_dtype = z.dtype
        z = self.project_in(z)

        codes = self.bound(z)
        indices = self.codes_to_indices(codes)

        out = self.project_out(codes)
        return out.to(orig_dtype), indices


class ResidualFSQ(nn.Module):
    def __init__(
        self,
        levels,
        num_quantizers,
        dim=None,
        bound_hard_clamp=True,
        device=None,
        dtype=None,
        operations=None,
        **kwargs
    ):
        super().__init__()

        codebook_dim = len(levels)
        dim = dim if dim is not None else codebook_dim

        requires_projection = codebook_dim != dim
        if requires_projection:
            self.project_in = operations.Linear(dim, codebook_dim, device=device, dtype=dtype)
            self.project_out = operations.Linear(codebook_dim, dim, device=device, dtype=dtype)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.layers = nn.ModuleList()
        levels_tensor = torch.tensor(levels, device=device)
        scales = []

        for ind in range(num_quantizers):
            scale_val = levels_tensor.float() ** -ind
            scales.append(scale_val)

            self.layers.append(FSQ(
                levels=levels,
                dim=codebook_dim,
                device=device,
                dtype=dtype,
                operations=operations
            ))

        scales_tensor = torch.stack(scales)
        if dtype is not None:
            scales_tensor = scales_tensor.to(dtype)
        self.register_buffer('scales', scales_tensor, persistent=False)

        if bound_hard_clamp:
            val = 1 + (1 / (levels_tensor.float() - 1))
            if dtype is not None:
                val = val.to(dtype)
            self.register_buffer('soft_clamp_input_value', val, persistent=False)

    def get_output_from_indices(self, indices, dtype=torch.float32):
        if indices.dim() == 2:
            indices = indices.unsqueeze(-1)

        all_codes = []
        for i, layer in enumerate(self.layers):
            idx = indices[..., i].long()
            codes = F.embedding(idx, comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype))
            all_codes.append(codes * comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype))

        codes_summed = torch.stack(all_codes).sum(dim=0)
        return self.project_out(codes_summed)

    def forward(self, x):
        x = self.project_in(x)

        if hasattr(self, 'soft_clamp_input_value'):
            sc_val = self.soft_clamp_input_value.to(x.dtype)
            x = (x / sc_val).tanh() * sc_val

        quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
        residual = x
        all_indices = []

        for layer, scale in zip(self.layers, self.scales):
            scale = scale.to(residual.dtype)

            quantized, indices = layer(residual / scale)
            quantized = quantized * scale

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized
            all_indices.append(indices)

        quantized_out = self.project_out(quantized_out)
        all_indices = torch.stack(all_indices, dim=-1)

        return quantized_out, all_indices


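# Audio tokenizer: projects acoustic latents to the model width, attention-pools
# each window of `pool_window_size` frames down to a single vector (the
# prepended special-token position), and quantizes the pooled sequence with
# ResidualFSQ. `tokenize` handles padding and reshaping (B, T, D) into windows.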
class AceStepAudioTokenizer(nn.Module):
    def __init__(
        self,
        audio_acoustic_hidden_dim,
        hidden_size,
        pool_window_size,
        fsq_dim,
        fsq_levels,
        fsq_input_num_quantizers,
        num_layers,
        head_dim,
        rms_norm_eps,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device)
        self.attention_pooler = AttentionPooler(
            hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations
        )
        self.pool_window_size = pool_window_size
        self.fsq_dim = fsq_dim
        self.quantizer = ResidualFSQ(
            dim=fsq_dim,
            levels=fsq_levels,
            num_quantizers=fsq_input_num_quantizers,
            bound_hard_clamp=True,
            dtype=dtype, device=device, operations=operations
        )

    def forward(self, hidden_states):
        hidden_states = self.audio_acoustic_proj(hidden_states)
        hidden_states = self.attention_pooler(hidden_states)
        quantized, indices = self.quantizer(hidden_states)
        return quantized, indices

    def tokenize(self, x):
        B, T, D = x.shape
        P = self.pool_window_size

        if T % P != 0:
            pad = P - (T % P)
            x = F.pad(x, (0, 0, 0, pad))
            T = x.shape[1]

        T_patch = T // P
        x = x.view(B, T_patch, P, D)

        quantized, indices = self.forward(x)
        return quantized, indices


class AudioTokenDetokenizer(nn.Module):
    def __init__(
        self,
        hidden_size,
        pool_window_size,
        audio_acoustic_hidden_dim,
        num_layers,
        head_dim,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        Linear = get_layer_class(operations, "Linear")
        self.pool_window_size = pool_window_size
        self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
        self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device))
        self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])
        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
        self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device)

    def forward(self, x):
        B, T, D = x.shape
        x = self.embed_tokens(x)
        x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
        x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype)
        x = x.view(B * T, self.pool_window_size, D)

        cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size)
        for layer in self.layers:
            x = layer(x, (cos, sin))

        x = self.norm(x)
        x = self.proj_out(x)
        return x.view(B, T * self.pool_window_size, -1)


class AceStepConditionGenerationModel(nn.Module):
    def __init__(
        self,
        in_channels=192,
        hidden_size=2048,
        text_hidden_dim=1024,
        timbre_hidden_dim=64,
        audio_acoustic_hidden_dim=64,
        num_dit_layers=24,
        num_lyric_layers=8,
        num_timbre_layers=4,
        num_tokenizer_layers=2,
        num_heads=16,
        num_kv_heads=8,
        head_dim=128,
        intermediate_size=6144,
        patch_size=2,
        pool_window_size=5,
        rms_norm_eps=1e-06,
        timestep_mu=-0.4,
        timestep_sigma=1.0,
        data_proportion=0.5,
        sliding_window=128,
        layer_types=None,
        fsq_dim=2048,
        fsq_levels=[8, 8, 8, 5, 5, 5],
        fsq_input_num_quantizers=1,
        audio_model=None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.dtype = dtype
        self.timestep_mu = timestep_mu
        self.timestep_sigma = timestep_sigma
        self.data_proportion = data_proportion
        self.pool_window_size = pool_window_size

        if layer_types is None:
            layer_types = []
            for i in range(num_dit_layers):
                layer_types.append("sliding_attention" if i % 2 == 0 else "full_attention")

        self.decoder = AceStepDiTModel(
            in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim,
            intermediate_size, patch_size, audio_acoustic_hidden_dim,
            layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps,
            dtype=dtype, device=device, operations=operations
        )
        self.encoder = AceStepConditionEncoder(
            text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers,
            num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
            dtype=dtype, device=device, operations=operations
        )
        self.tokenizer = AceStepAudioTokenizer(
            audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
            dtype=dtype, device=device, operations=operations
        )
        self.detokenizer = AudioTokenDetokenizer(
            hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
            dtype=dtype, device=device, operations=operations
        )
        self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))

    def prepare_condition(
        self,
        text_hidden_states, text_attention_mask,
        lyric_hidden_states, lyric_attention_mask,
        refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
        src_latents, chunk_masks, is_covers,
        precomputed_lm_hints_25Hz=None,
        audio_codes=None
    ):
        encoder_hidden, encoder_mask = self.encoder(
            text_hidden_states, text_attention_mask,
            lyric_hidden_states, lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask
        )

        if precomputed_lm_hints_25Hz is not None:
            lm_hints = precomputed_lm_hints_25Hz
        else:
            if audio_codes is not None:
                if audio_codes.shape[1] * 5 < src_latents.shape[1]:
                    audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
                lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
            else:
                assert False
                # TODO ?

            lm_hints = self.detokenizer(lm_hints_5Hz)

        lm_hints = lm_hints[:, :src_latents.shape[1], :]
        if is_covers is None:
            src_latents = lm_hints
        else:
            src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)

        context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)

        return encoder_hidden, encoder_mask, context_latents

    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
        text_attention_mask = None
        lyric_attention_mask = None
        refer_audio_order_mask = None
        attention_mask = None
        chunk_masks = None
        is_covers = None
        src_latents = None
        precomputed_lm_hints_25Hz = None
        lyric_hidden_states = lyric_embed
        text_hidden_states = context
        refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2)

        x = x.movedim(-1, -2)

        if refer_audio_order_mask is None:
            refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)

        if src_latents is None and is_covers is None:
            src_latents = x

        if chunk_masks is None:
            chunk_masks = torch.ones_like(x)

        enc_hidden, enc_mask, context_latents = self.prepare_condition(
            text_hidden_states, text_attention_mask,
            lyric_hidden_states, lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
            src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
        )

        out = self.decoder(hidden_states=x,
                           timestep=timestep,
                           timestep_r=timestep,
                           attention_mask=attention_mask,
                           encoder_hidden_states=enc_hidden,
                           encoder_attention_mask=enc_mask,
                           context_latents=context_latents
                           )

        return out.movedim(-1, -2)
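
# Rough shape sketch for AceStepConditionGenerationModel.forward (inferred from
# the projections above, not from upstream documentation):
#   x:           (B, audio_acoustic_hidden_dim, T)  latent, channels-first
#   timestep:    (B,) timesteps
#   context:     (B, L_text, text_hidden_dim)       text encoder states
#   lyric_embed: (B, L_lyric, text_hidden_dim)      lyric encoder inputs
#   refer_audio: (B, timbre_hidden_dim, T_ref)      reference-audio latents
#   audio_codes: (B, ceil(T / 5))                   ResidualFSQ token indices
# The output has the same (B, channels, T) layout as x.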