cleanup, enable fused rms norm by default

kijai 2026-04-10 15:28:26 +03:00
parent 05eaceafa1
commit 6718be09ba
2 changed files with 11 additions and 13 deletions


@@ -4,7 +4,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6, fused=True):
-if not fused:
+if not fused: # compatibility mode as torch native rms_norm results are slightly different
orig_dtype = x.dtype
normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
if weight is not None:

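For orientation, the hunk only shows the top of the non-fused branch. Below is a minimal sketch of how the two code paths could fit together, assuming the fused branch simply calls torch.nn.functional.rms_norm (PyTorch >= 2.4) and that weight is applied as a plain elementwise scale; the actual weight handling and any device/dtype casting are not visible in this diff.

import torch
import torch.nn.functional as F

def rms_norm(x, weight=None, eps=1e-6, fused=True):
    if not fused:  # compatibility mode as torch native rms_norm results are slightly different
        orig_dtype = x.dtype
        # reference path: 1/sqrt(mean(x^2) + eps) computed in float32, then cast back
        normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
        if weight is not None:
            normed = normed * weight.float()  # assumption: plain multiplicative weight
        return normed.to(orig_dtype)
    # fused path: single native kernel
    return F.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)

The "slightly different" results likely come from precision and ordering differences between the native kernel and the explicit float32 round trip; a quick parity check along these lines shows a small but nonzero gap:

x = torch.randn(2, 8, 256, dtype=torch.bfloat16)
w = torch.ones(256, dtype=torch.bfloat16)
print((rms_norm(x, w, fused=True).float() - rms_norm(x, w, fused=False).float()).abs().max())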

@@ -45,13 +45,11 @@ class Gemma4Config:
num_kv_shared_layers: int = 18
use_double_wide_mlp: bool = False
stop_tokens = [1, 50, 106]
-fused_rms_norm: bool = False # True = use fused F.rms_norm (~64% faster, minor output difference from reference)
+fused_rms_norm: bool = True # True = use fused F.rms_norm (a lot faster, minor output difference from reference)
vision_config = GEMMA4_VISION_CONFIG
audio_config = GEMMA4_AUDIO_CONFIG
mm_tokens_per_image = 280
-Gemma4_E4B_Config = Gemma4Config
@dataclass
class Gemma4_E2B_Config(Gemma4Config):
hidden_size: int = 1536
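Since fused_rms_norm now defaults to True, bit-exact parity with the reference implementation becomes opt-out rather than opt-in. A hypothetical way to restore the old behaviour, assuming the remaining fields keep their defaults:

cfg = Gemma4Config(fused_rms_norm=False)  # forces the slower float32 reference path everywhere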
@@ -104,7 +102,7 @@ class Gemma4Attention(nn.Module):
self.q_norm = None
self.k_norm = None
-fused = getattr(config, 'fused_rms_norm', False)
+fused = config.fused_rms_norm
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
if config.k_norm == "gemma3":
@@ -188,18 +186,18 @@ class TransformerBlockGemma4(nn.Module):
self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
-num_kv_shared = getattr(config, 'num_kv_shared_layers', 0)
+num_kv_shared = config.num_kv_shared_layers
first_kv_shared = config.num_hidden_layers - num_kv_shared
-mlp_size = config.intermediate_size * 2 if getattr(config, 'use_double_wide_mlp', False) and index >= first_kv_shared else None
+mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size)
-fused = getattr(config, 'fused_rms_norm', False)
+fused = config.fused_rms_norm
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
-self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0)
+self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
if self.hidden_size_per_layer_input:
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
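The getattr(config, '...', default) fallbacks removed throughout this file are the "cleanup" half of the commit: once every knob is declared on the config dataclass with a default, plain attribute access is equivalent and avoids carrying a second, possibly stale, fallback value. A toy illustration (hypothetical standalone snippet, not part of the diff):

from dataclasses import dataclass

@dataclass
class Cfg:
    num_kv_shared_layers: int = 18
    use_double_wide_mlp: bool = False

cfg = Cfg()
# the literal 0 in getattr could silently disagree with the declared default of 18;
# direct access always reflects whatever the dataclass (or the checkpoint's config_dict) set
assert getattr(cfg, "num_kv_shared_layers", 0) == cfg.num_kv_shared_layers == 18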
@@ -255,7 +253,7 @@ class Gemma4Transformer(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
-fused = getattr(config, 'fused_rms_norm', False)
+fused = config.fused_rms_norm
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False)
@@ -280,7 +278,7 @@ class Gemma4Transformer(nn.Module):
self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False)
# Per-layer input mechanism
-self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0)
+self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
if self.hidden_size_per_layer_input:
self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype)
self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False)
@@ -354,7 +352,7 @@ class Gemma4Transformer(nn.Module):
per_layer_inputs = per_layer_proj
# KV sharing: later layers reuse KV from the last non-shared sliding/global layer
-num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0)
+num_kv_shared = self.config.num_kv_shared_layers
first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers
shared_sliding_kv = None # KV from last non-shared sliding layer
shared_global_kv = None # KV from last non-shared global layer
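For a concrete sense of the index math above, a worked example with the config default num_kv_shared_layers = 18 and a hypothetical num_hidden_layers = 30:

num_hidden_layers, num_kv_shared = 30, 18  # layer count is illustrative only
first_kv_shared = num_hidden_layers - num_kv_shared if num_kv_shared > 0 else num_hidden_layers
# first_kv_shared == 12, so layers 12..29 reuse the KV from the last non-shared
# sliding layer or global layer (matching their own attention type) instead of computing their own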
@@ -450,7 +448,7 @@ class Gemma4AudioMixin:
class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
-self._init_model(Gemma4_E4B_Config(**config_dict), dtype, device, operations)
+self._init_model(Gemma4Config(**config_dict), dtype, device, operations)
self._init_audio(self.model.config, dtype, device, operations)