From f9c84c94b43855ec46e26afc944c06fb121cb9a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 May 2026 08:34:22 -0700 Subject: [PATCH] Support Stable Audio 3 model. (#14010) --- comfy/latent_formats.py | 5 + comfy/ldm/audio/dit.py | 250 +++++++++++++--- comfy/ldm/audio/embedders.py | 31 +- comfy/ldm/audio/vae_sa3.py | 533 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 79 ++++++ comfy/model_detection.py | 39 +++ comfy/sd.py | 37 +++ comfy/supported_models.py | 25 ++ comfy/text_encoders/sa3.py | 207 ++++++++++++++ 9 files changed, 1161 insertions(+), 45 deletions(-) create mode 100644 comfy/ldm/audio/vae_sa3.py create mode 100644 comfy/text_encoders/sa3.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6e37080bb..75d459b59 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -152,6 +152,11 @@ class StableAudio1(LatentFormat): latent_dimensions = 1 temporal_downscale_ratio = 2048 +class StableAudio3(LatentFormat): + latent_channels = 256 + latent_dimensions = 1 + temporal_downscale_ratio = 4096 + class Flux(SD3): latent_channels = 16 def __init__(self): diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index ca865189e..a6258b755 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -10,6 +10,17 @@ from torch import nn from torch.nn import functional as F import math import comfy.ops +from .embedders import ExpoFourierFeatures + + +def _left_pad_to_match(emb, target_len): + emb_len = emb.shape[-2] + if emb_len < target_len: + return F.pad(emb, (0, 0, target_len - emb_len, 0), value=0.) 
+ elif emb_len > target_len: + return emb[:, -target_len:, :] + return emb + class FourierFeatures(nn.Module): def __init__(self, in_features, out_features, std=1., dtype=None, device=None): @@ -22,6 +33,7 @@ class FourierFeatures(nn.Module): f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input) return torch.cat([f.cos(), f.sin()], dim=-1) + # norms class LayerNorm(nn.Module): def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None): @@ -43,6 +55,16 @@ class LayerNorm(nn.Module): beta = comfy.ops.cast_to_input(beta, x) return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta) + +class RMSNorm(nn.Module): + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + + def forward(self, x): + return F.rms_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x)) + + class GLU(nn.Module): def __init__( self, @@ -236,13 +258,6 @@ class FeedForward(nn.Module): linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device) - # # init last linear layer to 0 - # if zero_init_output: - # nn.init.zeros_(linear_out.weight) - # if not no_bias: - # nn.init.zeros_(linear_out.bias) - - self.ff = nn.Sequential( linear_in, rearrange('b d n -> b n d') if use_conv else nn.Identity(), @@ -261,8 +276,10 @@ class Attention(nn.Module): dim_context = None, causal = False, zero_init_output=True, - qk_norm = False, + qk_norm = "none", + differential = False, natten_kernel_size = None, + feat_scale = False, dtype=None, device=None, operations=None, @@ -271,6 +288,7 @@ class Attention(nn.Module): self.dim = dim self.dim_heads = dim_heads self.causal = causal + self.differential = differential dim_kv = dim_context if dim_context 
is not None else dim @@ -278,18 +296,37 @@ class Attention(nn.Module): self.kv_heads = dim_kv // dim_heads if dim_context is not None: - self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) - self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device) + if differential: + self.to_q = operations.Linear(dim, dim * 2, bias=False, dtype=dtype, device=device) + self.to_kv = operations.Linear(dim_kv, dim_kv * 3, bias=False, dtype=dtype, device=device) + else: + self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device) else: - self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device) + if differential: + self.to_qkv = operations.Linear(dim, dim * 5, bias=False, dtype=dtype, device=device) + else: + self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device) self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) - # if zero_init_output: - # nn.init.zeros_(self.to_out.weight) - + # Accept bool for backward compat + if isinstance(qk_norm, bool): + qk_norm = "l2" if qk_norm else "none" self.qk_norm = qk_norm + if self.qk_norm == "ln": + self.q_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + elif self.qk_norm == "rms": + self.q_norm = RMSNorm(dim_heads, dtype=dtype, device=device) + self.k_norm = RMSNorm(dim_heads, dtype=dtype, device=device) + + self.feat_scale = feat_scale + + if self.feat_scale: + self.lambda_dc = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + self.lambda_hf = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) def forward( self, @@ -306,22 +343,51 @@ class Attention(nn.Module): kv_input = context if has_context else x 
if hasattr(self, 'to_q'): - # Use separate linear projections for q and k/v - q = self.to_q(x) - q = rearrange(q, 'b n (h d) -> b h n d', h = h) + if self.differential: + # cross-attention differential: to_q → (q, q_diff), to_kv → (k, k_diff, v) + q, q_diff = self.to_q(x).chunk(2, dim=-1) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + q_diff = rearrange(q_diff, 'b n (h d) -> b h n d', h=h) + q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D) + k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) + k = rearrange(k, 'b n (h d) -> b h n d', h=kv_h) + k_diff = rearrange(k_diff, 'b n (h d) -> b h n d', h=kv_h) + v = rearrange(v, 'b n (h d) -> b h n d', h=kv_h) + k = torch.stack([k, k_diff], dim=1) # (B, 2, H, M, D) + else: + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) - k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) else: - # Use fused linear projection - q, k, v = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + if self.differential: + # self-attention differential: to_qkv → (q, k, v, q_diff, k_diff) + q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) + q, k, v, q_diff, k_diff = map( + lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), + (q, k, v, q_diff, k_diff) + ) + q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D) + k = torch.stack([k, k_diff], dim=1) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) # Normalize q and k for cosine sim attention - if self.qk_norm: + if self.qk_norm == "l2": q = F.normalize(q, dim=-1) k = F.normalize(k, dim=-1) + elif self.qk_norm == "rms": + q_type, k_type = 
q.dtype, k.dtype + q = self.q_norm(q).to(q_type) + k = self.k_norm(k).to(k_type) + elif self.qk_norm != 'none': + q = self.q_norm(q) + k = self.k_norm(k) if rotary_pos_emb is not None and not has_context: freqs, _ = rotary_pos_emb @@ -364,9 +430,24 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + if self.differential: + q, q_diff = q.unbind(dim=1) + k, k_diff = k.unbind(dim=1) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options) + out = out - out_diff + else: + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out = self.to_out(out) + if self.feat_scale: + out_dc = out.mean(dim=-2, keepdim=True) + out_hf = out - out_dc + + # Selectively modulate DC and high frequency components + out = out + comfy.ops.cast_to_input(self.lambda_dc, out) * out_dc + comfy.ops.cast_to_input(self.lambda_hf, out) * out_hf + if mask is not None: mask = rearrange(mask, 'b n -> b n 1') out = out.masked_fill(~mask, 0.) 
@@ -417,11 +498,14 @@ class TransformerBlock(nn.Module): cross_attend = False, dim_context = None, global_cond_dim = None, + global_cond_shared_embed = False, + local_add_cond_dim = None, causal = False, zero_init_branch_outputs = True, conformer = False, layer_ix = -1, remove_norms = False, + norm_type = "layer_norm", attn_kwargs = {}, ff_kwargs = {}, norm_kwargs = {}, @@ -436,8 +520,20 @@ class TransformerBlock(nn.Module): self.cross_attend = cross_attend self.dim_context = dim_context self.causal = causal + self.global_cond_shared_embed = global_cond_shared_embed - self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + norm_layer_map = { + "layer_norm": LayerNorm, + "rms_norm": RMSNorm, + } + norm_cls = norm_layer_map.get(norm_type, LayerNorm) + + def make_norm(): + if remove_norms: + return nn.Identity() + return norm_cls(dim, dtype=dtype, device=device, **norm_kwargs) + + self.pre_norm = make_norm() self.self_attn = Attention( dim, @@ -451,7 +547,7 @@ class TransformerBlock(nn.Module): ) if cross_attend: - self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attend_norm = make_norm() self.cross_attn = Attention( dim, dim_heads = dim_heads, @@ -464,37 +560,56 @@ class TransformerBlock(nn.Module): **attn_kwargs ) - self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() - self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs) + self.ff_norm = make_norm() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations, **ff_kwargs) self.layer_ix = layer_ix self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None - self.global_cond_dim = global_cond_dim + # Global conditioning + self.has_global_cond = 
(global_cond_dim is not None) or global_cond_shared_embed - if global_cond_dim is not None: + if global_cond_shared_embed: + # SA3 style: learnable per-block additive bias; global_cond is pre-projected to (B, dim*6) + self.to_scale_shift_gate = nn.Parameter(torch.empty(dim * 6, device=device, dtype=dtype)) + elif global_cond_dim is not None: + # SA1 style: per-block MLP projects global_cond → (B, dim*6) self.to_scale_shift_gate = nn.Sequential( nn.SiLU(), - nn.Linear(global_cond_dim, dim * 6, bias=False) + operations.Linear(global_cond_dim, dim * 6, bias=False, device=device, dtype=dtype) ) - nn.init.zeros_(self.to_scale_shift_gate[1].weight) - #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + # Local additive conditioning (e.g. inpaint mask + masked latent) + self.local_add_cond_dim = local_add_cond_dim + if local_add_cond_dim is not None: + self.to_local_embed = nn.Sequential( + operations.Linear(local_add_cond_dim, dim, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(dim, dim, bias=True, dtype=dtype, device=device), + ) + else: + self.to_local_embed = None def forward( self, x, context = None, global_cond=None, + local_add_cond=None, mask = None, context_mask = None, rotary_pos_emb = None, transformer_options={} ): - if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + if self.has_global_cond and global_cond is not None: + if self.global_cond_shared_embed: + # global_cond already has shape (B, dim*6) + ssg = (comfy.ops.cast_to_input(self.to_scale_shift_gate, global_cond) + global_cond).unsqueeze(1) + else: + ssg = self.to_scale_shift_gate(global_cond).unsqueeze(1) - scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = ssg.chunk(6, dim = -1) # self-attention with adaLN residual = x @@ -510,6 +625,9 @@ class TransformerBlock(nn.Module): if 
self.conformer is not None: x = x + self.conformer(x) + if local_add_cond is not None and self.to_local_embed is not None: + x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2]) + # feedforward with adaLN residual = x x = self.ff_norm(x) @@ -527,6 +645,9 @@ class TransformerBlock(nn.Module): if self.conformer is not None: x = x + self.conformer(x) + if local_add_cond is not None and self.to_local_embed is not None: + x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2]) + x = x + self.ff(self.ff_norm(x)) return x @@ -543,6 +664,8 @@ class ContinuousTransformer(nn.Module): cross_attend=False, cond_token_dim=None, global_cond_dim=None, + global_cond_shared_embed=False, + local_add_cond_dim=None, causal=False, rotary_pos_emb=True, zero_init_branch_outputs=True, @@ -550,6 +673,7 @@ class ContinuousTransformer(nn.Module): use_sinusoidal_emb=False, use_abs_pos_emb=False, abs_pos_emb_max_length=10000, + num_memory_tokens=0, dtype=None, device=None, operations=None, @@ -562,6 +686,8 @@ class ContinuousTransformer(nn.Module): self.depth = depth self.causal = causal self.layers = nn.ModuleList([]) + self.num_memory_tokens = num_memory_tokens + self.global_cond_shared_embed = global_cond_shared_embed self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity() self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity() @@ -577,7 +703,22 @@ class ContinuousTransformer(nn.Module): self.use_abs_pos_emb = use_abs_pos_emb if use_abs_pos_emb: - self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + num_memory_tokens) + + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.empty(num_memory_tokens, dim, device=device, dtype=dtype)) + + # Shared global-cond embedder (SA3 style): projects (B, 
global_cond_dim) → (B, dim*6) + self.global_cond_embedder = None + if global_cond_shared_embed and global_cond_dim is not None: + self.global_cond_embedder = nn.Sequential( + operations.Linear(global_cond_dim, dim, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(dim, dim * 6, bias=True, dtype=dtype, device=device), + ) + + # When using shared embed, TransformerBlocks use per-block Parameter (not per-block MLP) + block_global_cond_dim = None if global_cond_shared_embed else global_cond_dim for i in range(depth): self.layers.append( @@ -586,7 +727,9 @@ class ContinuousTransformer(nn.Module): dim_heads = dim_heads, cross_attend = cross_attend, dim_context = cond_token_dim, - global_cond_dim = global_cond_dim, + global_cond_dim = block_global_cond_dim, + global_cond_shared_embed = global_cond_shared_embed, + local_add_cond_dim = local_add_cond_dim, causal = causal, zero_init_branch_outputs = zero_init_branch_outputs, conformer=conformer, @@ -605,6 +748,7 @@ class ContinuousTransformer(nn.Module): prepend_embeds = None, prepend_mask = None, global_cond = None, + local_add_cond = None, return_info = False, **kwargs ): @@ -632,7 +776,9 @@ class ContinuousTransformer(nn.Module): mask = torch.cat((prepend_mask, mask), dim = -1) - # Attention layers + if self.num_memory_tokens > 0: + memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1) + x = torch.cat((memory_tokens, x), dim=1) if self.rotary_pos_emb is not None: rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device) @@ -642,6 +788,10 @@ class ContinuousTransformer(nn.Module): if self.use_sinusoidal_emb or self.use_abs_pos_emb: x = x + self.pos_emb(x) + # Project global_cond once (SA3 shared-embed path) + if global_cond is not None and self.global_cond_embedder is not None: + global_cond = self.global_cond_embedder(global_cond) + blocks_replace = patches_replace.get("dit", {}) # Iterate over the transformer layers for 
i, layer in enumerate(self.layers): @@ -654,12 +804,17 @@ class ContinuousTransformer(nn.Module): out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) - # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + x = layer(x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, + local_add_cond=local_add_cond, context=context, + transformer_options=transformer_options) if return_info: info["hidden_states"].append(x) + # Strip memory tokens before projecting out + if self.num_memory_tokens > 0: + x = x[:, self.num_memory_tokens:, :] + x = self.project_out(x) if return_info: @@ -682,6 +837,7 @@ class AudioDiffusionTransformer(nn.Module): num_heads=24, transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + timestep_features_type: str = "learned", audio_model="", dtype=None, device=None, @@ -696,7 +852,10 @@ class AudioDiffusionTransformer(nn.Module): # Timestep embeddings timestep_features_dim = 256 - self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device) + if timestep_features_type == "expo": + self.timestep_features = ExpoFourierFeatures(timestep_features_dim, 0.5, 10000.0) + else: + self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device) self.to_timestep_embed = nn.Sequential( operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device), @@ -781,6 +940,7 @@ class AudioDiffusionTransformer(nn.Module): cross_attn_cond=None, cross_attn_cond_mask=None, input_concat_cond=None, + local_add_cond=None, global_embed=None, prepend_cond=None, 
prepend_cond_mask=None, @@ -802,9 +962,13 @@ class AudioDiffusionTransformer(nn.Module): prepend_cond = self.to_prepend_embed(prepend_cond) prepend_inputs = prepend_cond + prepend_length = prepend_cond.shape[1] if prepend_cond_mask is not None: prepend_mask = prepend_cond_mask + if local_add_cond is not None and local_add_cond.dim() == 3: + local_add_cond = local_add_cond.permute(0, 2, 1) + if input_concat_cond is not None: # Interpolate input_concat_cond to the same length as x @@ -850,7 +1014,7 @@ class AudioDiffusionTransformer(nn.Module): if self.transformer_type == "x-transformers": output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) elif self.transformer_type == "continuous_transformer": - output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, local_add_cond=local_add_cond, **extra_args, **kwargs) if return_info: output, info = output @@ -876,6 +1040,7 @@ class AudioDiffusionTransformer(nn.Module): context=None, context_mask=None, input_concat_cond=None, + local_add_cond=None, global_embed=None, negative_global_embed=None, prepend_cond=None, @@ -890,6 +1055,7 @@ class AudioDiffusionTransformer(nn.Module): cross_attn_cond=context, cross_attn_cond_mask=context_mask, input_concat_cond=input_concat_cond, + local_add_cond=local_add_cond, global_embed=global_embed, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, diff --git a/comfy/ldm/audio/embedders.py b/comfy/ldm/audio/embedders.py index 20edb365a..ba9a62837 100644 --- a/comfy/ldm/audio/embedders.py +++ b/comfy/ldm/audio/embedders.py @@ 
-31,15 +31,39 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: ) +class ExpoFourierFeatures(nn.Module): + """Exponentially-spaced Fourier features (no learnable parameters).""" + def __init__(self, dim, min_freq=0.5, max_freq=10000.0): + super().__init__() + self.dim = dim + self.min_freq = min_freq + self.max_freq = max_freq + + def forward(self, t): + in_dtype = t.dtype + t = t.float() + if t.dim() == 1: + t = t.unsqueeze(-1) + half_dim = self.dim // 2 + ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32) + freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq)) + args = t * freqs * 2 * math.pi + return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype) + + class NumberEmbedder(nn.Module): def __init__( self, features: int, dim: int = 256, + fourier_features_type="learned", ): super().__init__() self.features = features - self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + if fourier_features_type == "expo": + self.embedding = nn.Sequential(ExpoFourierFeatures(dim=dim), comfy.ops.manual_cast.Linear(in_features=dim, out_features=features)) + else: + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) def forward(self, x: Union[List[float], Tensor]) -> Tensor: if not torch.is_tensor(x): @@ -77,14 +101,15 @@ class NumberConditioner(Conditioner): def __init__(self, output_dim: int, min_val: float=0, - max_val: float=1 + max_val: float=1, + fourier_features_type: str = "learned", ): super().__init__(output_dim, output_dim) self.min_val = min_val self.max_val = max_val - self.embedder = NumberEmbedder(features=output_dim) + self.embedder = NumberEmbedder(features=output_dim, fourier_features_type=fourier_features_type) def forward(self, floats, device=None): # Cast the inputs to floats diff --git a/comfy/ldm/audio/vae_sa3.py b/comfy/ldm/audio/vae_sa3.py new file mode 100644 index 000000000..276846444 --- /dev/null +++ 
b/comfy/ldm/audio/vae_sa3.py @@ -0,0 +1,533 @@ +import torch +import torch.nn as nn + +import comfy.ops +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.audio.autoencoder import WNConv1d + +ops = comfy.ops.disable_weight_init + +class Transpose(nn.Module): + def forward(self, x, **kwargs): + return x.transpose(-2, -1) + + +def _zero_pad_modulo_sequence(x, size, dim=-2): + input_len = x.shape[dim] + pad_len = (size - input_len % size) % size + if pad_len > 0: + pad_shape = list(x.shape) + pad_shape[dim] = pad_len + x = torch.cat([x, torch.zeros(pad_shape, device=x.device, dtype=x.dtype)], dim=dim) + return x + + +def _sliding_window_mask(seq_len, window, device, dtype): + """Additive attention mask enforcing a ±window local window (matches flash_attn window_size).""" + i = torch.arange(seq_len, device=device).unsqueeze(1) + j = torch.arange(seq_len, device=device).unsqueeze(0) + out_of_window = (j - i).abs() > window + return torch.where( + out_of_window, + torch.full((1,), torch.finfo(dtype).min / 4, device=device, dtype=dtype), + torch.zeros(1, device=device, dtype=dtype), + ) + + +class DynamicTanh(nn.Module): + def __init__(self, dim, init_alpha=4.0, dtype=None, device=None, **kwargs): + super().__init__() + self.alpha = nn.Parameter(torch.empty(1, dtype=dtype, device=device)) + self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + + def forward(self, x): + alpha = comfy.ops.cast_to_input(self.alpha, x) + gamma = comfy.ops.cast_to_input(self.gamma, x) + beta = comfy.ops.cast_to_input(self.beta, x) + return gamma * torch.tanh(alpha * x) + beta + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000, base_rescale_factor=1., dtype=None, device=None): + super().__init__() + base = base * base_rescale_factor ** (dim / (dim - 2)) + self.register_buffer("inv_freq", torch.empty(dim // 2, dtype=dtype, 
device=device)) + + def forward_from_seq_len(self, seq_len, device, dtype=None): + t = torch.arange(seq_len, device=device, dtype=torch.float32) + return self.forward(t) + + def forward(self, t): + freqs = torch.outer(t.float(), comfy.model_management.cast_to(self.inv_freq, dtype=torch.float32, device=t.device)) + freqs = torch.cat((freqs, freqs), dim=-1) + return freqs, 1. + + +def _rotate_half(x): + d = x.shape[-1] // 2 + return torch.cat((-x[..., d:], x[..., :d]), dim=-1) + + +def _apply_rotary_pos_emb(t, freqs): + out_dtype = t.dtype + rot_dim = freqs.shape[-1] + seq_len = t.shape[-2] + freqs = freqs[-seq_len:] + t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:] + t_rot = t_rot * freqs.cos() + _rotate_half(t_rot) * freqs.sin() + return torch.cat((t_rot.to(out_dtype), t_pass.to(out_dtype)), dim=-1) + + +class Attention(nn.Module): + def __init__(self, dim, dim_heads=64, qk_norm="none", qk_norm_eps=1e-6, + differential=False, zero_init_output=True, + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.num_heads = dim // dim_heads + self.differential = differential + self.qk_norm = qk_norm + + self.to_qkv = operations.Linear( + dim, dim * (5 if differential else 3), bias=False, dtype=dtype, device=device) + self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + + if qk_norm == "dyt": + self.q_norm = DynamicTanh(dim_heads, dtype=dtype, device=device) + self.k_norm = DynamicTanh(dim_heads, dtype=dtype, device=device) + elif qk_norm == "rms": + self.q_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device) + + def forward(self, x, rotary_pos_emb=None, mask=None, **kwargs): + B, N, _ = x.shape + h = self.num_heads + + qkv = self.to_qkv(x) + if self.differential: + q, k, v, q_diff, k_diff = qkv.chunk(5, dim=-1) + del qkv + q = q.view(B, N, h, -1).transpose(1, 2) + k = k.view(B, N, h, 
-1).transpose(1, 2) + v = v.view(B, N, h, -1).transpose(1, 2) + q_diff = q_diff.view(B, N, h, -1).transpose(1, 2) + k_diff = k_diff.view(B, N, h, -1).transpose(1, 2) + else: + q, k, v = qkv.chunk(3, dim=-1) + del qkv + q = q.view(B, N, h, -1).transpose(1, 2) + k = k.view(B, N, h, -1).transpose(1, 2) + v = v.view(B, N, h, -1).transpose(1, 2) + + if self.qk_norm != "none": + q_dtype, k_dtype = q.dtype, k.dtype + q = self.q_norm(q).to(q_dtype) + k = self.k_norm(k).to(k_dtype) + if self.differential: + q_diff = self.q_norm(q_diff).to(q_dtype) + k_diff = self.k_norm(k_diff).to(k_dtype) + + if rotary_pos_emb is not None: + freqs, _ = rotary_pos_emb + q_dtype, k_dtype = q.dtype, k.dtype + q = _apply_rotary_pos_emb(q.float(), freqs).to(q_dtype) + k = _apply_rotary_pos_emb(k.float(), freqs).to(k_dtype) + if self.differential: + q_diff = _apply_rotary_pos_emb(q_diff.float(), freqs).to(q_dtype) + k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype) + + if self.differential: + out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True)) + del q, k, v, q_diff, k_diff + else: + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + del q, k, v + + return self.to_out(out) + + +class _Sin(nn.Module): + def forward(self, x): + return torch.sin(3.14159265359 * x) + + +class _GLU(nn.Module): + def __init__(self, dim_in, dim_out, activation, dtype=None, device=None, operations=None): + super().__init__() + self.act = activation + self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=4, no_bias=False, zero_init_output=True, + sinusoidal=False, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + inner_dim = int(dim * mult) + act = _Sin() if sinusoidal else 
nn.SiLU() + self.ff = nn.Sequential( + _GLU(dim, inner_dim, act, dtype=dtype, device=device, operations=operations), + nn.Identity(), + operations.Linear(inner_dim, dim, bias=not no_bias, dtype=dtype, device=device), + nn.Identity(), + ) + + def forward(self, x, **kwargs): + return self.ff(x) + + +class TransformerBlock(nn.Module): + def __init__(self, dim, dim_heads=64, causal=False, zero_init_branch_outputs=True, + norm_type="dyt", add_rope=False, attn_kwargs=None, ff_kwargs=None, + norm_kwargs=None, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if attn_kwargs is None: + attn_kwargs = {} + if ff_kwargs is None: + ff_kwargs = {} + if norm_kwargs is None: + norm_kwargs = {} + dim_heads = min(dim_heads, dim) + + Norm = DynamicTanh if norm_type == "dyt" else operations.RMSNorm + norm_kw = {**norm_kwargs, "dtype": dtype, "device": device} + + self.pre_norm = Norm(dim, **norm_kw) + self.self_attn = Attention(dim, dim_heads=dim_heads, + zero_init_output=zero_init_branch_outputs, + dtype=dtype, device=device, operations=operations, + **attn_kwargs) + self.ff_norm = Norm(dim, **norm_kw) + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, + dtype=dtype, device=device, operations=operations, **ff_kwargs) + self.rope = RotaryEmbedding(dim_heads // 2, dtype=dtype, device=device) if add_rope else None + + def forward(self, x, mask=None, **kwargs): + rope = self.rope.forward_from_seq_len(x.shape[-2], device=x.device) \ + if self.rope is not None else None + x = x + self.self_attn(self.pre_norm(x), rotary_pos_emb=rope, mask=mask) + x = x + self.ff(self.ff_norm(x)) + return x + + +class TransformerResamplingBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, type="encoder", + transformer_depth=3, dim_heads=128, differential=True, + sliding_window=None, chunk_size=128, chunk_midpoint_shift=False, + dyt=True, ff_mult=3, mapping_bias=True, variable_stride=False, + sinusoidal_blocks=0, conv_mapping=False, 
dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if type not in ("encoder", "decoder"): + raise ValueError(f"type must be 'encoder' or 'decoder', got {type!r}") + + self.type = type + self.stride = stride + self.chunk_size = chunk_size + self.chunk_midpoint_shift = chunk_midpoint_shift + self.variable_stride = variable_stride + self.transformer_depth = transformer_depth + + transformer_dim = out_channels if type == "encoder" else in_channels + + self.mapping = (WNConv1d(in_channels, out_channels, 3 if conv_mapping else 1, padding="same", bias=mapping_bias) + if in_channels != out_channels else nn.Identity()) + + self.sliding_window_latents = sliding_window + self.sliding_window_seq = self._get_sliding_window_size(sliding_window, stride) + self.input_seg_size, self.output_seg_size, self.sub_chunk_size = self._get_seg_sizes(stride) + + token_seq = 1 if variable_stride else self.output_seg_size + self.new_tokens = nn.Parameter(torch.empty(1, token_seq, transformer_dim, dtype=dtype, device=device)) + + norm_type = "dyt" if dyt else "rms_norm" + attn_kwargs = {"qk_norm": "dyt" if dyt else "rms", "qk_norm_eps": 1e-3, + "differential": differential} + norm_kwargs = {"eps": 1e-3} + transformers = [] + for i in range(transformer_depth): + sinusoidal = (transformer_depth - i) < sinusoidal_blocks + transformers.append(TransformerBlock( + transformer_dim, + dim_heads=dim_heads, + causal=False, + zero_init_branch_outputs=True, + norm_type=norm_type, + add_rope=True, + attn_kwargs=attn_kwargs, + ff_kwargs={"mult": ff_mult, "no_bias": False, "sinusoidal": sinusoidal}, + norm_kwargs=norm_kwargs, + dtype=dtype, device=device, operations=operations, + )) + self.transformers = nn.ModuleList(transformers) + + def _get_sliding_window_size(self, window, stride, prepend_cond_length=0): + if window is None: + return None + return [w * (stride + 1 + prepend_cond_length) for w in window] + + def _get_seg_sizes(self, stride, prepend_cond_length=0): + sub_chunk_size 
= stride + 1 + prepend_cond_length + input_seg_size = stride if self.type == "encoder" else 1 + output_seg_size = 1 if self.type == "encoder" else stride + return input_seg_size, output_seg_size, sub_chunk_size + + def forward(self, x, stride=None, **kwargs): + B = x.shape[0] + + if stride is None: + input_seg = self.input_seg_size + output_seg = self.output_seg_size + sub_chunk = self.sub_chunk_size + sliding_window = self.sliding_window_seq + else: + input_seg, output_seg, sub_chunk = self._get_seg_sizes(stride) + sliding_window = self._get_sliding_window_size(self.sliding_window_latents, stride) + + if self.type == "encoder": + if self.transformer_depth > 0: + pad_mod = self.chunk_size if sliding_window is None else input_seg + x = _zero_pad_modulo_sequence(x, pad_mod, dim=-1) + x = self.mapping(x) + + if self.transformer_depth > 0: + x = x.permute(0, 2, 1) + + if self.type != "encoder": + pad_mod = 1 if sliding_window is not None else ( + self.chunk_size // (stride if stride is not None else self.stride)) + x = _zero_pad_modulo_sequence(x, pad_mod) + + C = x.shape[2] + x = x.reshape(-1, input_seg, C) + + new_tokens = self.new_tokens.expand(x.shape[0], output_seg, -1) + x = torch.cat([x, comfy.ops.cast_to_input(new_tokens, x)], dim=-2) + del new_tokens + + x = x.reshape(B, -1, C) + + if sliding_window is None: + eff_chunk = self.chunk_size + self.chunk_size // (stride if stride is not None else self.stride) + + if sliding_window is None and self.chunk_midpoint_shift: + split = self.transformer_depth // 2 + shift = eff_chunk // 2 + + x = x.reshape(-1, eff_chunk, C) + for layer in self.transformers[:split]: + x = layer(x) + x = x.reshape(B, -1, C) + + shifted = torch.cat([x[:, :shift, :], x, x[:, -shift:, :]], dim=1) + del x + x = shifted.reshape(-1, eff_chunk, C) + del shifted + for layer in self.transformers[split:]: + x = layer(x) + x = x.reshape(B, -1, C) + x = x[:, shift:-shift, :] + elif sliding_window is None: + x = x.reshape(-1, eff_chunk, C) + for layer 
class SAMEEncoder(nn.Module):
    """Encoder half of the SAME autoencoder.

    Builds one encoder-type ``TransformerResamplingBlock`` per entry of
    ``c_mults`` (channel widths ``channels * c_mult``), then projects the last
    width to ``latent_dim`` with a ``Linear`` sandwiched between two
    ``Transpose`` modules (presumably so the projection acts on the channel
    axis -- TODO confirm against ``Transpose``).

    Any extra keyword arguments are forwarded verbatim to every resampling
    block.  ``operations`` must provide ``Linear``.
    """

    def __init__(self, in_channels=2, channels=128, latent_dim=32,
                 c_mults=(1, 2, 4, 8), strides=(2, 4, 8, 8),
                 transformer_depths=(3, 3, 3, 3),
                 dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        # widths[k] -> widths[k + 1] is the channel change performed by stage k.
        widths = [in_channels, *(channels * mult for mult in c_mults)]

        stages = [
            TransformerResamplingBlock(
                in_channels=width_in, out_channels=width_out,
                stride=strides[idx], type="encoder",
                transformer_depth=transformer_depths[idx],
                dtype=dtype, device=device, operations=operations, **kwargs)
            for idx, (width_in, width_out) in enumerate(zip(widths, widths[1:]))
        ]

        # Final projection to the latent width; module order fixes the
        # state-dict indices, so it must stay blocks -> Transpose/Linear/Transpose.
        stages.append(Transpose())
        stages.append(operations.Linear(widths[-1], latent_dim, dtype=dtype, device=device))
        stages.append(Transpose())

        self.layers = nn.ModuleList(stages)

    def forward(self, x, **kwargs):
        """Run ``x`` through every stage in order; ``kwargs`` are ignored."""
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
class PatchedPretransform(nn.Module):
    """Weight-free patchify / un-patchify transform for ``(B, C, T)`` tensors.

    ``encode`` folds every run of ``patch_size`` consecutive samples of a
    channel into the channel axis (``b c (l h) -> b (c h) l``); ``decode`` is
    the exact inverse (``b (c h) l -> b c (l h)``).
    """

    def __init__(self, channels, patch_size, **kwargs):
        super().__init__()
        # Stored for callers; neither attribute is read inside this class.
        self.channels = channels
        self.enable_grad = False
        self.patch_size = patch_size

    def _pad(self, x):
        """Right-pad the last axis with zeros to a multiple of ``patch_size``."""
        pad_len = (self.patch_size - x.shape[-1] % self.patch_size) % self.patch_size
        if pad_len > 0:
            # F.pad appends exactly pad_len zeros.  Building the zero block by
            # slicing the input (zeros_like(x[:, :, :pad_len])) produced too
            # few zeros whenever the input was shorter than pad_len, leaving a
            # length that is not a multiple of patch_size and breaking encode().
            x = torch.nn.functional.pad(x, (0, pad_len))
        return x

    def encode(self, x):
        """Return ``x`` (zero-padded if needed) as ``(B, C * patch_size, L)``.

        Output channel ``c * patch_size + h`` at step ``l`` holds input
        channel ``c`` at sample ``l * patch_size + h``.
        """
        x = self._pad(x)
        B, C, T = x.shape
        h = self.patch_size
        L = T // h
        # b c (l h) -> b (c h) l
        return x.reshape(B, C, L, h).permute(0, 1, 3, 2).reshape(B, C * h, L)

    def decode(self, x):
        """Invert ``encode``: ``(B, C * patch_size, L)`` -> ``(B, C, L * patch_size)``.

        Padding added by ``encode`` is not removed here.
        """
        B, Ch, L = x.shape
        h = self.patch_size
        C = Ch // h
        # b (c h) l -> b c (l h)
        return x.reshape(B, C, h, L).permute(0, 1, 3, 2).reshape(B, C, L * h)
class StableAudio3(BaseModel):
    """Stable Audio 3 diffusion model wrapper.

    Wraps ``AudioDiffusionTransformer`` and owns the two conditioning pieces
    that live outside the transformer weights: the ``seconds_total`` number
    embedder and an optional padding embedding used to pad the text prompt.
    """

    def __init__(self, model_config, seconds_total_embedder_weights, padding_embedding=None, model_type=ModelType.FLOW, device=None):
        """Build the model and load the ``seconds_total`` embedder weights.

        ``padding_embedding`` (optional tensor) is stored as a frozen
        parameter; when absent the attribute is set to ``None``.
        """
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
        self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=384, fourier_features_type=model_config.unet_config["timestep_features_type"])
        self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
        if padding_embedding is not None:
            self.padding_embedding = torch.nn.Parameter(padding_embedding, requires_grad=False)
        else:
            self.padding_embedding = None

    def concat_cond(self, **kwargs):
        """Return ``cat((mask, image), dim=1)`` built from the sampling kwargs.

        ``image`` is ``concat_latent_image`` (passed through
        ``process_latent_in``) or zeros shaped like ``noise``.  ``mask`` is
        ``concat_mask`` / ``denoise_mask`` reduced to one channel and
        inverted, or a single zero channel when no mask is given.
        """
        noise = kwargs.get("noise", None)
        image = kwargs.get("concat_latent_image", None)

        if image is None:
            shape_image = list(noise.shape)
            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
        else:
            image = self.process_latent_in(image)
            # TODO: scale if not match
            image = utils.resize_to_batch_size(image, noise.shape[0])

        # "concat_mask" takes precedence; "denoise_mask" is the fallback.
        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            mask = torch.zeros_like(noise)[:, :1]
        else:
            if mask.shape[1] != 1:
                # Collapse a multi-channel mask to a single channel.
                mask = torch.mean(mask, dim=1, keepdim=True)
            mask = 1.0 - mask
            # TODO: scale if not match
            mask = utils.resize_to_batch_size(mask, noise.shape[0])

        return torch.cat((mask, image), dim=1)

    def extra_conds(self, **kwargs):
        """Assemble the extra conditioning dict for the transformer.

        Keys written: ``local_add_cond`` (mask + latent from ``concat_cond``),
        ``global_embed`` (flattened seconds-total embedding) and, when text
        conditioning is present, ``c_crossattn``.
        """
        out = {}

        concat_cond = self.concat_cond(**kwargs)
        if concat_cond is not None:
            out['local_add_cond'] = comfy.conds.CONDNoiseShape(concat_cond)

        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        # Default duration is derived from the latent length.  10.7666 is
        # approximately 44100 / 4096 -- presumably latent frames per second
        # (TODO confirm).  The default expression is evaluated even when
        # "seconds_total" is supplied, so ``noise`` must not be None.
        seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 10.7666))
        seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)

        global_embed = seconds_total_embed.reshape((1, -1))
        out['global_embed'] = comfy.conds.CONDRegular(global_embed)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            cross_attn = cross_attn.to(device)
            if self.padding_embedding is not None:
                pe = self.padding_embedding.to(device=device, dtype=cross_attn.dtype)
                max_text_tokens = self.model_config.unet_config.get("max_text_tokens", 256)
                n_text = cross_attn.shape[1]
                # Only pads up to max_text_tokens; longer prompts are left
                # untouched (no truncation).
                if n_text < max_text_tokens:
                    pad = pe.view(1, 1, -1).expand(cross_attn.shape[0], max_text_tokens - n_text, -1)
                    cross_attn = torch.cat([cross_attn, pad], dim=1)
            # Append the seconds-total embedding along the token axis,
            # repeated once per batch element.
            cross_attn = torch.cat([cross_attn, seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

        return out

    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        """Extend the base state dict with the conditioner weights.

        Adds the embedder weights under
        ``conditioner.conditioners.seconds_total.`` and, when present, the
        padding embedding under
        ``conditioner.conditioners.prompt.padding_embedding``.
        """
        sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)

        # prefix -> sub state dict to be flattened into ``sd``
        d = {"conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}

        for k in d:
            s = d[k]
            for l in s:
                sd["{}{}".format(k, l)] = s[l]

        if self.padding_embedding is not None:
            sd["conditioner.conditioners.prompt.padding_embedding"] = self.padding_embedding.data
        return sd
unet_config["timestep_features_type"] = "expo" + + io_channels = state_dict['{}postprocess_conv.weight'.format(key_prefix)].shape[0] + unet_config["io_channels"] = io_channels + unet_config["input_concat_dim"] = state_dict['{}transformer.project_in.weight'.format(key_prefix)].shape[1] - io_channels + + local_add_cond = state_dict.get('{}transformer.layers.0.to_local_embed.0.weight'.format(key_prefix), None) + if local_add_cond is not None: + unet_config["local_add_cond_dim"] = local_add_cond.shape[1] + + global_cond_embed = state_dict.get('{}transformer.global_cond_embedder.0.weight'.format(key_prefix), None) + if global_cond_embed is not None: + unet_config["global_cond_shared_embed"] = True + unet_config["global_cond_type"] = "adaLN" + + unet_config["depth"] = count_blocks(state_dict_keys, '{}transformer.layers.'.format(key_prefix) + '{}.') return unet_config if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit diff --git a/comfy/sd.py b/comfy/sd.py index 2443353a4..7bd07ed3a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -21,6 +21,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder +import comfy.ldm.audio.vae_sa3 import comfy.pixel_space_convert import comfy.weight_adapter import yaml @@ -67,6 +68,7 @@ import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo +import comfy.text_encoders.sa3 import comfy.model_patcher import comfy.lora @@ -854,6 +856,34 @@ class VAE: self.working_dtypes = [torch.float32] self.disable_offload = True self.extra_1d_channel = 16 + elif "decoder.layers.3.transformers.0.pre_norm.alpha" in sd: # Stable Audio 3 VAE + if "decoder.layers.3.transformers.11.self_attn.to_out.weight" in sd: + config = {"channels": 256, "transformer_depths": 12, "sinusoidal_blocks": 8, + "sliding_window": [1, 1], "decoder_conv_mapping": False, 
+ "chunk_size": 128, "chunk_midpoint_shift": False} + self.memory_used_encode = lambda shape, dtype: (1500 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * 4096) * model_management.dtype_size(dtype) + else: + config = {"channels": 128, "transformer_depths": 6, "sinusoidal_blocks": 0, + "sliding_window": None, "decoder_conv_mapping": True, + "chunk_size": 32, "chunk_midpoint_shift": True} + self.memory_used_encode = lambda shape, dtype: (72 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (72 * shape[2] * 4096) * model_management.dtype_size(dtype) + + self.first_stage_model = comfy.ldm.audio.vae_sa3.SA3AudioVAE(**config) + self.latent_channels = 256 + self.output_channels = 2 + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 1 + self.audio_sample_rate = 44100 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + #This VAE has Parameters and Buffers the non-dynamic caster cannot handle + #Force cast it for --disable-dynamic-vram users until there is a true core fix. 
+ if not comfy.memory_management.aimdo_enabled: + self.disable_offload = True else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -1290,6 +1320,7 @@ class TEModel(Enum): GEMMA_4_E4B = 29 GEMMA_4_E2B = 30 GEMMA_4_31B = 31 + T5_GEMMA = 32 def detect_te_model(sd): @@ -1314,6 +1345,8 @@ def detect_te_model(sd): if weight.shape[0] == 384: return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE + if "model.encoder.layers.0.pre_self_attn_layernorm.weight" in sd: + return TEModel.T5_GEMMA if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.59.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_4_31B @@ -1463,6 +1496,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer + elif te_model == TEModel.T5_GEMMA: + clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel + clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1e4434fd5..617db4f28 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -7,6 +7,7 @@ from . 
class StableAudio3(StableAudio):
    """Supported-model entry for Stable Audio 3 checkpoints.

    Shares ``audio_model == "dit1.0"`` with ``StableAudio`` and additionally
    sets ``global_cond_shared_embed``.  How these keys are matched against a
    detected config is defined by the base class (not visible here).
    """

    unet_config = {
        "audio_model": "dit1.0",
        "global_cond_shared_embed": True,
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 2.0,
    }

    latent_format = latent_formats.StableAudio3

    memory_usage_factor = 7

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate ``model_base.StableAudio3`` from a full checkpoint.

        Extracts the ``seconds_total`` conditioner weights (prefix stripped)
        and the optional prompt padding embedding from ``state_dict``;
        ``prefix`` is not used.
        """
        seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
        # None when the checkpoint carries no padding embedding.
        padding_embedding = state_dict.get("conditioner.conditioners.prompt.padding_embedding", None)
        return model_base.StableAudio3(self, seconds_total_embedder_weights=seconds_total_sd, padding_embedding=padding_embedding, device=device)

    def clip_target(self, state_dict={}):
        """Return the T5Gemma tokenizer / text-encoder pair; ``state_dict`` is unused."""
        return supported_models_base.ClipTarget(comfy.text_encoders.sa3.SAT5GemmaTokenizer, comfy.text_encoders.sa3.SAT5GemmaModel)
class T5GemmaEncoderConfig:
    """Fixed hyper-parameters for the T5Gemma text encoder.

    The constructor takes no arguments; every value is hard-coded and read
    as a plain attribute.
    """

    def __init__(self):
        # Vocabulary and layer widths.
        self.vocab_size = 256000
        self.hidden_size = 768
        self.intermediate_size = 2048
        self.num_hidden_layers = 12
        # Attention geometry: as many KV heads as query heads.
        self.num_attention_heads = 12
        self.num_key_value_heads = 12
        self.head_dim = 64
        # Normalisation.  ``rms_norm_add`` used to be assigned twice (False,
        # then True); only the final True ever took effect, so the dead
        # assignment was dropped.
        self.rms_norm_eps = 1e-6
        self.rms_norm_add = True
        self.rope_theta = 10000.0
        # Attention logits are soft-capped with tanh(logits / cap) * cap.
        self.attn_logit_softcapping = 50.0
        # Queries are scaled by query_pre_attn_scalar ** -0.5.
        self.query_pre_attn_scalar = 64
        self.sliding_window = 4096
        self.mlp_activation = "gelu_pytorch_tanh"
        # Layers alternate sliding / full attention; derived from the layer
        # count so the two can never disagree.
        self.layer_types = ["sliding_attention", "full_attention"] * (self.num_hidden_layers // 2)
        self.qkv_bias = False
        self.q_norm = None
        self.k_norm = None
+ """ + + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__(config, device=device, dtype=dtype, ops=ops) + self.scale = config.query_pre_attn_scalar ** -0.5 + self.softcap = config.attn_logit_softcapping + + def forward(self, hidden_states, attention_mask=None, freqs_cis=None, **kwargs): + B, S, _ = hidden_states.shape + xq = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + xk = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + xv = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + xq, xk = apply_rope(xq, xk, freqs_cis) + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + attn = torch.matmul(xq * self.scale, xk.transpose(-2, -1)) + attn = torch.tanh(attn / self.softcap) * self.softcap + if attention_mask is not None: + attn = attn + attention_mask + attn = torch.nn.functional.softmax(attn.float(), dim=-1).to(xq.dtype) + out = torch.matmul(attn, xv).transpose(1, 2).reshape(B, S, self.inner_size) + return self.o_proj(out), None + + +class T5GemmaBlock(nn.Module): + def __init__(self, config, layer_type, device=None, dtype=None, ops=None): + super().__init__() + self.self_attn = T5GemmaAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + # Names match checkpoint keys: model.encoder.layers.X..weight + self.pre_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, 
class T5GemmaEncoder(nn.Module):
    """Token embedding, a stack of ``T5GemmaBlock`` layers and a final norm.

    Parameter names under this module are ``embed_tokens.*``,
    ``layers.<index>.*`` and ``norm.*``.
    """

    def __init__(self, config, device, dtype, ops):
        super().__init__()
        self.config = config
        # Gemma-style scaled embedding: output *= sqrt(hidden_size)
        self.embed_tokens = _make_scaled_embedding(
            ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)

        blocks = []
        for depth in range(config.num_hidden_layers):
            blocks.append(T5GemmaBlock(config, config.layer_types[depth],
                                       device=device, dtype=dtype, ops=ops))
        self.layers = nn.ModuleList(blocks)

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype)

    def forward(self, input_ids, attention_mask=None, embeds=None, intermediate_output=None,
                final_layer_norm_intermediate=True, dtype=None, num_layers=None):
        """Encode tokens (or precomputed ``embeds``).

        Returns ``(final, intermediate)``: ``final`` is the normed output of
        the last layer; ``intermediate`` is a copy of the output of layer
        index ``intermediate_output`` (normed too when
        ``final_layer_norm_intermediate`` is true), or ``None`` when no layer
        index matches.  ``num_layers`` is accepted but not used.
        """
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(input_ids, out_dtype=dtype or torch.float32)

        seq_len = x.shape[1]
        position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
        freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, device=x.device)

        # Turn the keep-mask into an additive bias: 0 where attention is
        # allowed, the most negative finite value where it is not.
        mask = None
        if attention_mask is not None:
            batch = attention_mask.shape[0]
            keys = attention_mask.shape[-1]
            keep = attention_mask.to(x.dtype).reshape((batch, 1, -1, keys)).expand(batch, 1, seq_len, keys)
            mask = 1.0 - keep
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        intermediate = None
        for index, block in enumerate(self.layers):
            x = block(x, attention_mask=mask, freqs_cis=freqs_cis)
            if index == intermediate_output:
                intermediate = x.clone()

        x = self.norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.norm(intermediate)
        return x, intermediate
class T5GemmaModel(nn.Module):
    """Top-level text model handed to ``SDClipModel`` as ``model_class``.

    The encoder is reachable as ``self.model.encoder`` so parameter names
    line up with checkpoint keys of the form ``model.encoder.*``.
    """

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        # ``config_dict`` is accepted for interface compatibility only; the
        # fixed T5GemmaEncoderConfig is always used.
        encoder_config = T5GemmaEncoderConfig()
        self.num_layers = encoder_config.num_hidden_layers
        self.dtype = dtype
        self.model = T5GemmaBody(encoder_config, device, dtype, operations)

    def get_input_embeddings(self):
        """Return the token-embedding module of the encoder."""
        encoder = self.model.encoder
        return encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        """Replace the token-embedding module of the encoder."""
        encoder = self.model.encoder
        encoder.embed_tokens = embeddings

    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None,
                intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, **kwargs):
        """Delegate to the encoder and return its result unchanged.

        A negative ``intermediate_output`` counts from the end of the layer
        stack and is converted to a non-negative layer index first.
        ``num_tokens`` and extra keyword arguments are ignored.
        """
        layer_index = intermediate_output
        if layer_index is not None and layer_index < 0:
            layer_index = self.num_layers + layer_index

        encoder = self.model.encoder
        return encoder(
            input_ids,
            attention_mask=attention_mask,
            embeds=embeds,
            intermediate_output=layer_index,
            final_layer_norm_intermediate=final_layer_norm_intermediate,
            dtype=dtype,
            num_layers=self.num_layers,
        )
class SAT5GemmaModel(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` wrapper that registers the T5Gemma encoder as ``"t5gemma"``.

    The name matches the ``clip_name`` used by ``SAT5GemmaTokenizer``.
    """

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        # All arguments are forwarded unchanged; only the encoder name and
        # the concrete clip model class are fixed here.
        super().__init__(device=device, dtype=dtype, model_options=model_options,
                         name="t5gemma", clip_model=T5GemmaSDClipModel, **kwargs)