Add a way to disable the final norm in the llama based TE models. (#10794)

This commit is contained in:
comfyanonymous 2025-11-18 19:36:03 -08:00 committed by GitHub
parent b5c8be8b1d
commit 17027f2a6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,6 +32,7 @@ class Llama2Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_3BConfig:
@ -53,6 +54,7 @@ class Qwen25_3BConfig:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_7BVLI_Config:
@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Gemma2_2B_Config:
@ -96,6 +99,7 @@ class Gemma2_2B_Config:
k_norm = None
sliding_attention = None
rope_scale = None
final_norm: bool = True
@dataclass
class Gemma3_4B_Config:
@ -118,6 +122,7 @@ class Gemma3_4B_Config:
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
final_norm: bool = True
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -366,7 +371,12 @@ class Llama2_(nn.Module):
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@ -421,14 +431,16 @@ class Llama2_(nn.Module):
if i == intermediate_output:
intermediate = x.clone()
x = self.norm(x)
if self.norm is not None:
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate:
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)
return x, intermediate