From 40862c07760a33ebd2e54b16b640f9e9f51f8946 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 7 Apr 2026 00:13:47 -0700 Subject: [PATCH] Support Ace Step 1.5 XL model. (#13317) --- comfy/ldm/ace/ace_step15.py | 18 +++++++++++------- comfy/model_detection.py | 9 +++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index 1d7dc59a8..2ca2d26c4 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -611,6 +611,7 @@ class AceStepDiTModel(nn.Module): intermediate_size, patch_size, audio_acoustic_hidden_dim, + condition_dim=None, layer_types=None, sliding_window=128, rms_norm_eps=1e-6, @@ -640,7 +641,7 @@ class AceStepDiTModel(nn.Module): self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) - self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.condition_embedder = Linear(condition_dim, hidden_size, dtype=dtype, device=device) if layer_types is None: layer_types = ["full_attention"] * num_layers @@ -1035,6 +1036,9 @@ class AceStepConditionGenerationModel(nn.Module): fsq_dim=2048, fsq_levels=[8, 8, 8, 5, 5, 5], fsq_input_num_quantizers=1, + encoder_hidden_size=2048, + encoder_intermediate_size=6144, + encoder_num_heads=16, audio_model=None, dtype=None, device=None, @@ -1054,24 +1058,24 @@ class AceStepConditionGenerationModel(nn.Module): self.decoder = AceStepDiTModel( in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim, - intermediate_size, patch_size, audio_acoustic_hidden_dim, + intermediate_size, patch_size, audio_acoustic_hidden_dim, condition_dim=encoder_hidden_size, layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps, dtype=dtype, device=device, operations=operations ) self.encoder = AceStepConditionEncoder( - text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers, - num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + text_hidden_dim, timbre_hidden_dim, encoder_hidden_size, num_lyric_layers, num_timbre_layers, + encoder_num_heads, num_kv_heads, head_dim, encoder_intermediate_size, rms_norm_eps, dtype=dtype, device=device, operations=operations ) self.tokenizer = AceStepAudioTokenizer( - audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps, + audio_acoustic_hidden_dim, encoder_hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps, dtype=dtype, device=device, operations=operations ) self.detokenizer = AudioTokenDetokenizer( - hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim, + encoder_hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim, dtype=dtype, device=device, operations=operations ) - self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + self.null_condition_emb = nn.Parameter(torch.empty(1, 1, encoder_hidden_size, dtype=dtype, device=device)) def prepare_condition( self, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 1c8ae2325..8bed6828d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -696,6 +696,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys: dit_config = {} dit_config["audio_model"] = "ace1.5" + head_dim = 128 + dit_config["hidden_size"] = state_dict['{}decoder.layers.0.self_attn_norm.weight'.format(key_prefix)].shape[0] + dit_config["intermediate_size"] = state_dict['{}decoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0] + dit_config["num_heads"] = state_dict['{}decoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim + + dit_config["encoder_hidden_size"] = state_dict['{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix)].shape[0] + dit_config["encoder_num_heads"] = state_dict['{}encoder.lyric_encoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim + dit_config["encoder_intermediate_size"] = state_dict['{}encoder.lyric_encoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0] + dit_config["num_dit_layers"] = count_blocks(state_dict_keys, '{}decoder.layers.'.format(key_prefix) + '{}.') return dit_config if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4