From 17027f2a6a20a31e2c6f3be2b1a06f39ad3a68d9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 18 Nov 2025 19:36:03 -0800 Subject: [PATCH] Add a way to disable the final norm in the llama based TE models. (#10794) --- comfy/text_encoders/llama.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index c050759fe..feb44bbb0 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -32,6 +32,7 @@ class Llama2Config: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True @dataclass class Qwen25_3BConfig: @@ -53,6 +54,7 @@ class Qwen25_3BConfig: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True @dataclass class Qwen25_7BVLI_Config: @@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True @dataclass class Gemma2_2B_Config: @@ -96,6 +99,7 @@ class Gemma2_2B_Config: k_norm = None sliding_attention = None rope_scale = None + final_norm: bool = True @dataclass class Gemma3_4B_Config: @@ -118,6 +122,7 @@ class Gemma3_4B_Config: k_norm = "gemma3" sliding_attention = [False, False, False, False, False, 1024] rope_scale = [1.0, 8.0] + final_norm: bool = True class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -366,7 +371,12 @@ class Llama2_(nn.Module): transformer(config, index=i, device=device, dtype=dtype, ops=ops) for i in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + if config.final_norm: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + else: + self.norm = None + # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): @@ -421,14 +431,16 @@ class Llama2_(nn.Module): if i == intermediate_output: intermediate = x.clone() - x = self.norm(x) + if self.norm is not None: + x = self.norm(x) + if all_intermediate is not None: all_intermediate.append(x.unsqueeze(1).clone()) if all_intermediate is not None: intermediate = torch.cat(all_intermediate, dim=1) - if intermediate is not None and final_layer_norm_intermediate: + if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) return x, intermediate