From 55ebd287ee72f637615496d5de916b2645577f0c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 11 Apr 2026 18:06:36 -0700 Subject: [PATCH 1/4] Add a supports_fp64 function. (#13368) --- comfy/ldm/flux/math.py | 2 +- comfy/model_management.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 824daf5e6..6d0aed827 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 - if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): + if not comfy.model_management.supports_fp64(pos.device): device = torch.device("cpu") else: device = pos.device diff --git a/comfy/model_management.py b/comfy/model_management.py index 0eebf1ded..bcf1399c4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1732,6 +1732,21 @@ def supports_mxfp8_compute(device=None): return True +def supports_fp64(device=None): + if is_device_mps(device): + return False + + if is_intel_xpu(): + return False + + if is_directml_enabled(): + return False + + if is_ixuca(): + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): From 31283d2892f54caf9bfdf6edb9c98cbfa88c5f0c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 11 Apr 2026 19:29:31 -0700 Subject: [PATCH 2/4] Implement Ernie Image model. (#13369) --- comfy/ldm/ernie/model.py | 303 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 12 ++ comfy/model_detection.py | 5 + comfy/sd.py | 8 + comfy/supported_models.py | 34 +++- comfy/text_encoders/ernie.py | 38 +++++ comfy/text_encoders/flux.py | 4 +- comfy/text_encoders/llama.py | 32 ++++ 8 files changed, 433 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/ernie/model.py create mode 100644 comfy/text_encoders/ernie.py diff --git a/comfy/ldm/ernie/model.py b/comfy/ldm/ernie/model.py new file mode 100644 index 000000000..1f8f08376 --- /dev/null +++ b/comfy/ldm/ernie/model.py @@ -0,0 +1,303 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0 + if not comfy.model_management.supports_fp64(pos.device): + device = torch.device("cpu") + else: + device = pos.device + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), torch.sin(out)], dim=0) + return out.to(dtype=torch.float32, device=pos.device) + +def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + rot_dim = freqs_cis.shape[-1] + x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] + cos_ = freqs_cis[0] + sin_ = freqs_cis[1] + x1, x2 = x.chunk(2, dim=-1) + x_rotated = torch.cat((-x2, x1), dim=-1) + return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) + +class ErnieImageEmbedND3(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: tuple): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = list(axes_dim) + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) + emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2] + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim] + +class ErnieImagePatchEmbedDynamic(nn.Module): + def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None): + super().__init__() + self.patch_size = patch_size + self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + batch_size, dim, height, width = x.shape + return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous() + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool = False): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + if self.flip_sin_to_cos: + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + else: + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + return emb + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None): + super().__init__() + Linear = operations.Linear + self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, device=device, dtype=dtype) + self.act = nn.SiLU() + self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, device=device, dtype=dtype) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + +class ErnieImageAttention(nn.Module): + def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None): + super().__init__() + self.heads = heads + self.head_dim = dim_head + self.inner_dim = heads * dim_head + + Linear = operations.Linear + RMSNorm = operations.RMSNorm + + self.to_q = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype) + self.to_k = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype) + self.to_v = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype) + + self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype) + self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype) + + self.to_out = nn.ModuleList([Linear(self.inner_dim, query_dim, bias=False, device=device, dtype=dtype)]) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor: + B, S, _ = x.shape + + q_flat = self.to_q(x) + k_flat = self.to_k(x) + v_flat = self.to_v(x) + + query = q_flat.view(B, S, self.heads, self.head_dim) + key = k_flat.view(B, S, self.heads, self.head_dim) + + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + query, key = query.to(x.dtype), key.to(x.dtype) + + q_flat = query.reshape(B, S, -1) + k_flat = key.reshape(B, S, -1) + + hidden_states = optimized_attention(q_flat, k_flat, v_flat, self.heads, mask=attention_mask) + + return self.to_out[0](hidden_states) + +class ErnieImageFeedForward(nn.Module): + def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None): + super().__init__() + Linear = operations.Linear + self.gate_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype) + self.up_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype) + self.linear_fc2 = Linear(ffn_hidden_size, hidden_size, bias=False, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) + +class ErnieImageSharedAdaLNBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None): + super().__init__() + RMSNorm = operations.RMSNorm + + self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype) + self.self_attention = ErnieImageAttention( + query_dim=hidden_size, + dim_head=hidden_size // num_heads, + heads=num_heads, + eps=eps, + operations=operations, + device=device, + dtype=dtype + ) + self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype) + self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, device=device, dtype=dtype) + + def forward(self, x, rotary_pos_emb, temb, attention_mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb + + residual = x + x_norm = self.adaLN_sa_ln(x) + x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) + + attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) + x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype) + + residual = x + x_norm = self.adaLN_mlp_ln(x) + x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) + + return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype) + +class ErnieImageAdaLNContinuous(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None): + super().__init__() + LayerNorm = operations.LayerNorm + Linear = operations.Linear + self.norm = LayerNorm(hidden_size, elementwise_affine=False, eps=eps, device=device, dtype=dtype) + self.linear = Linear(hidden_size, hidden_size * 2, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(conditioning).chunk(2, dim=-1) + x = self.norm(x) + x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return x + +class ErnieImageModel(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_layers: int = 36, + ffn_hidden_size: int = 12288, + in_channels: int = 128, + out_channels: int = 128, + patch_size: int = 1, + text_in_dim: int = 3072, + rope_theta: int = 256, + rope_axes_dim: tuple = (32, 48, 48), + eps: float = 1e-6, + qk_layernorm: bool = True, + device=None, + dtype=None, + operations=None, + **kwargs + ): + super().__init__() + self.dtype = dtype + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.patch_size = patch_size + self.out_channels = out_channels + + Linear = operations.Linear + + self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, device, dtype) + self.text_proj = Linear(text_in_dim, hidden_size, bias=False, device=device, dtype=dtype) if text_in_dim != hidden_size else None + + self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False) + self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, device, dtype) + + self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + Linear(hidden_size, 6 * hidden_size, device=device, dtype=dtype) + ) + + self.layers = nn.ModuleList([ + ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, device, dtype) + for _ in range(num_layers) + ]) + + self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, device, dtype) + self.final_linear = Linear(hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype) + + def forward(self, x, timesteps, context, **kwargs): + device, dtype = x.device, x.dtype + B, C, H, W = x.shape + p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size + N_img = Hp * Wp + + img_bsh = self.x_embedder(x) + + text_bth = context + if self.text_proj is not None and text_bth.numel() > 0: + text_bth = self.text_proj(text_bth) + Tmax = text_bth.shape[1] + + hidden_states = torch.cat([img_bsh, text_bth], dim=1) + + text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32) + text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32) + index = float(Tmax) + + transformer_options = kwargs.get("transformer_options", {}) + rope_options = transformer_options.get("rope_options", None) + + h_len, w_len = float(Hp), float(Wp) + h_offset, w_offset = 0.0, 0.0 + + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32) + image_ids[:, :, 0] = image_ids[:, :, 1] + index + image_ids[:, :, 1] = image_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32).unsqueeze(1) + image_ids[:, :, 2] = image_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32).unsqueeze(0) + + image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1) + + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype) + del image_ids, text_ids + + sample = self.time_proj(timesteps.to(dtype)).to(self.time_embedding.linear_1.weight.dtype) + c = self.time_embedding(sample) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ + t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1) + ] + + temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] + for layer in self.layers: + hidden_states = layer(hidden_states, rotary_pos_emb, temb) + + hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states) + + patches = self.final_linear(hidden_states)[:, :N_img, :] + output = ( + patches.view(B, Hp, Wp, p, p, self.out_channels) + .permute(0, 5, 1, 3, 2, 4) + .contiguous() + .view(B, self.out_channels, H, W) + ) + + return output diff --git a/comfy/model_base.py b/comfy/model_base.py index c2ae646aa..5c2668ba9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -53,6 +53,7 @@ import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 import comfy.ldm.rt_detr.rtdetr_v4 +import comfy.ldm.ernie.model import comfy.model_management import comfy.patcher_extension @@ -1962,3 +1963,14 @@ class Kandinsky5Image(Kandinsky5): class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) + +class ErnieImage(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8bed6828d..ca06cdd1e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -713,6 +713,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config + if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image + dit_config = {} + dit_config["image_model"] = "ernie" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index f331feefb..e573804a5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -62,6 +62,7 @@ import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 +import comfy.text_encoders.ernie import comfy.model_patcher import comfy.lora @@ -1235,6 +1236,7 @@ class TEModel(Enum): QWEN35_4B = 25 QWEN35_9B = 26 QWEN35_27B = 27 + MINISTRAL_3_3B = 28 def detect_te_model(sd): @@ -1301,6 +1303,8 @@ def detect_te_model(sd): return TEModel.MISTRAL3_24B else: return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + if weight.shape[0] == 3072: + return TEModel.MINISTRAL_3_3B return TEModel.LLAMA3_8 return None @@ -1458,6 +1462,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_06B: clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer + elif te_model == TEModel.MINISTRAL_3_3B: + clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9a5612716..58d4ce731 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -26,6 +26,7 @@ import comfy.text_encoders.z_image import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image +import comfy.text_encoders.ernie from . import supported_models_base from . import latent_formats @@ -1749,6 +1750,37 @@ class RT_DETR_v4(supported_models_base.BASE): def clip_target(self, state_dict={}): return None -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4] + +class ErnieImage(supported_models_base.BASE): + unet_config = { + "image_model": "ernie", + } + + sampling_settings = { + "multiplier": 1000.0, + "shift": 3.0, + } + + memory_usage_factor = 10.0 + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.ErnieImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}ministral3_3b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/ernie.py b/comfy/text_encoders/ernie.py new file mode 100644 index 000000000..8c56c1c11 --- /dev/null +++ b/comfy/text_encoders/ernie.py @@ -0,0 +1,38 @@ +from .flux import Mistral3Tokenizer +from comfy import sd1_clip +import comfy.text_encoders.llama + +class Ministral3_3BTokenizer(Mistral3Tokenizer): + def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}): + return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data) + +class ErnieTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="ministral3_3b", tokenizer=Mistral3Tokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + tokens = super().tokenize_with_weights(text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Ministral3_3BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = {} + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ministral3_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ErnieTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="ministral3_3b", clip_model=Ministral3_3BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class ErnieTEModel_(ErnieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ErnieTEModel diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 1ae398789..d5eb91dcb 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -116,9 +116,9 @@ class MistralTokenizerClass: return LlamaTokenizerFast(**kwargs) class Mistral3Tokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 06f2fbf74..6cdc47757 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -60,6 +60,29 @@ class Mistral3Small24BConfig: final_norm: bool = True lm_head: bool = False +@dataclass +class Ministral3_3BConfig: + vocab_size: int = 131072 + hidden_size: int = 3072 + intermediate_size: int = 9216 + num_hidden_layers: int = 26 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 262144 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + lm_head: bool = False + @dataclass class Qwen25_3BConfig: vocab_size: int = 151936 @@ -946,6 +969,15 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Ministral3_3B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Ministral3_3BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() From 971932346ac6e6e02c1e1e8cfe34df2f0e1cea3e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:27:38 -0700 Subject: [PATCH 3/4] Update quant doc so it's not completely wrong. (#13381) There is still more that needs to be fixed. --- QUANTIZATION.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/QUANTIZATION.md b/QUANTIZATION.md index 1693e13f3..300822029 100644 --- a/QUANTIZATION.md +++ b/QUANTIZATION.md @@ -139,9 +139,9 @@ Example: "_quantization_metadata": { "format_version": "1.0", "layers": { - "model.layers.0.mlp.up_proj": "float8_e4m3fn", - "model.layers.0.mlp.down_proj": "float8_e4m3fn", - "model.layers.1.mlp.up_proj": "float8_e4m3fn" + "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"}, + "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"}, + "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"} } } } @@ -165,4 +165,4 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s 3. **Compute scales**: Derive `input_scale` from collected statistics 4. **Store in checkpoint**: Save `input_scale` parameters alongside weights -The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. \ No newline at end of file +The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. From c2657d5fb9ccfae150c8e5d0e1b39780a0cc33e9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:37:13 -0700 Subject: [PATCH 4/4] Fix typo. (#13382) --- comfy/text_encoders/ernie.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/ernie.py b/comfy/text_encoders/ernie.py index 8c56c1c11..2c7df78fe 100644 --- a/comfy/text_encoders/ernie.py +++ b/comfy/text_encoders/ernie.py @@ -3,7 +3,7 @@ from comfy import sd1_clip import comfy.text_encoders.llama class Ministral3_3BTokenizer(Mistral3Tokenizer): - def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}): + def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}): return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data) class ErnieTokenizer(sd1_clip.SD1Tokenizer):