From fe2511468d6b3f7a1581a711ae480f3313f2b39d Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 3 Feb 2026 16:01:38 -0800
Subject: [PATCH] Support the 4B ace step 1.5 lm model. (#12257)

Can be used as an alternative to the 1.7B
---
 comfy/sd.py                  |  7 +++-
 comfy/supported_models.py    | 12 +++++-
 comfy/text_encoders/ace15.py | 42 ++++++++++++++++-----
 comfy/text_encoders/llama.py | 72 ++++++++++++++++++++++++++----------
 4 files changed, 101 insertions(+), 32 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index aa93e101d..bc63d6ced 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1444,7 +1444,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
             tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
         elif clip_type == CLIPType.ACE:
-            clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data))
+            te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
+            if TEModel.QWEN3_4B in te_models:
+                model_type = "qwen3_4b"
+            else:
+                model_type = "qwen3_2b"
+            clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 6b7d831cb..77264ed28 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1625,8 +1625,16 @@ class ACEStep15(supported_models_base.BASE):
 
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
-        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
-        return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect))
+        detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
+        detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
+        if "dtype_llama" in detect_2b:
+            detect = detect_2b
+            detect["lm_model"] = "qwen3_2b"
+        elif "dtype_llama" in detect_4b:
+            detect = detect_4b
+            detect["lm_model"] = "qwen3_4b"
+
+        return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
 
 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
 
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 48c29fa9a..73d710671 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -162,14 +162,34 @@ class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
 
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
+class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
 
 class ACE15TEModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}):
+    def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
         super().__init__()
         if dtype_llama is None:
             dtype_llama = dtype
+        model = None
+        self.constant = 0.4375
+        if lm_model == "qwen3_4b":
+            model = Qwen3_4B_ACE15
+            self.constant = 0.5625
+        elif lm_model == "qwen3_2b":
+            model = Qwen3_2B_ACE15
+
+        self.lm_model = lm_model
         self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
-        self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options)
+        if model is not None:
+            setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
+
         self.dtypes = set([dtype, dtype_llama])
 
     def encode_token_weights(self, token_weight_pairs):
@@ -182,17 +202,21 @@ class ACE15TEModel(torch.nn.Module):
         lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
 
         lm_metadata = token_weight_pairs["lm_metadata"]
-        audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
+        audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
 
         return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
 
     def set_clip_options(self, options):
         self.qwen3_06b.set_clip_options(options)
-        self.qwen3_2b.set_clip_options(options)
+        lm_model = getattr(self, self.lm_model, None)
+        if lm_model is not None:
+            lm_model.set_clip_options(options)
 
     def reset_clip_options(self):
         self.qwen3_06b.reset_clip_options()
-        self.qwen3_2b.reset_clip_options()
+        lm_model = getattr(self, self.lm_model, None)
+        if lm_model is not None:
+            lm_model.reset_clip_options()
 
     def load_sd(self, sd):
         if "model.layers.0.post_attention_layernorm.weight" in sd:
@@ -200,11 +224,11 @@ class ACE15TEModel(torch.nn.Module):
             if shape[0] == 1024:
                 return self.qwen3_06b.load_sd(sd)
             else:
-                return self.qwen3_2b.load_sd(sd)
+                return getattr(self, self.lm_model).load_sd(sd)
 
     def memory_estimation_function(self, token_weight_pairs, device=None):
         lm_metadata = token_weight_pairs["lm_metadata"]
-        constant = 0.4375
+        constant = self.constant
 
         if comfy.model_management.should_use_bf16(device):
             constant *= 0.5
@@ -213,11 +237,11 @@ class ACE15TEModel(torch.nn.Module):
         num_tokens += lm_metadata['min_tokens']
         return num_tokens * constant * 1024 * 1024
 
-def te(dtype_llama=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
     class ACE15TEModel_(ACE15TEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
                 model_options["llama_quantization_metadata"] = llama_quantization_metadata
-            super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options)
+            super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
     return ACE15TEModel_
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index f4ac224c6..3afd094d1 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -150,6 +150,29 @@ class Qwen3_2B_ACE15_lm_Config:
     final_norm: bool = True
     lm_head: bool = False
 
+@dataclass
+class Qwen3_4B_ACE15_lm_Config:
+    vocab_size: int = 217204
+    hidden_size: int = 2560
+    intermediate_size: int = 9728
+    num_hidden_layers: int = 36
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+    lm_head: bool = False
+
 @dataclass
 class Qwen3_4BConfig:
     vocab_size: int = 151936
@@ -739,6 +762,21 @@ class BaseLlama:
     def forward(self, input_ids, *args, **kwargs):
         return self.model(input_ids, *args, **kwargs)
 
+class BaseQwen3:
+    def logits(self, x):
+        input = x[:, -1:]
+        module = self.model.embed_tokens
+
+        offload_stream = None
+        if module.comfy_cast_weights:
+            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+        else:
+            weight = self.model.embed_tokens.weight.to(x)
+
+        x = torch.nn.functional.linear(input, weight, None)
+
+        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+        return x
 
 class Llama2(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
@@ -767,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_06B(BaseLlama, torch.nn.Module):
+class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_06BConfig(**config_dict)
@@ -776,7 +814,7 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
+class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_06B_ACE15_Config(**config_dict)
@@ -785,7 +823,7 @@ class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
+class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_2B_ACE15_lm_Config(**config_dict)
@@ -794,22 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-    def logits(self, x):
-        input = x[:, -1:]
-        module = self.model.embed_tokens
-
-        offload_stream = None
-        if module.comfy_cast_weights:
-            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
-        else:
-            weight = self.model.embed_tokens.weight.to(x)
-
-        x = torch.nn.functional.linear(input, weight, None)
-
-        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
-        return x
-
-class Qwen3_4B(BaseLlama, torch.nn.Module):
+class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_4BConfig(**config_dict)
@@ -818,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_8B(BaseLlama, torch.nn.Module):
+class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen3_4B_ACE15_lm_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
+class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_8BConfig(**config_dict)