From 6718be09bae34e941aa62f5447891099668cb12f Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 10 Apr 2026 15:28:26 +0300
Subject: [PATCH] cleanup, enable fused rms norm by default

---
 comfy/rmsnorm.py              |  2 +-
 comfy/text_encoders/gemma4.py | 22 ++++++++++------------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
index 5e5ef359a..af0978341 100644
--- a/comfy/rmsnorm.py
+++ b/comfy/rmsnorm.py
@@ -4,7 +4,7 @@ import comfy.model_management
 RMSNorm = torch.nn.RMSNorm
 
 def rms_norm(x, weight=None, eps=1e-6, fused=True):
-    if not fused:
+    if not fused:  # compatibility mode, as torch native rms_norm results are slightly different
         orig_dtype = x.dtype
         normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
         if weight is not None:
diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py
index 9fac8c66a..1442f63a7 100644
--- a/comfy/text_encoders/gemma4.py
+++ b/comfy/text_encoders/gemma4.py
@@ -45,13 +45,11 @@ class Gemma4Config:
     num_kv_shared_layers: int = 18
     use_double_wide_mlp: bool = False
     stop_tokens = [1, 50, 106]
-    fused_rms_norm: bool = False  # True = use fused F.rms_norm (~64% faster, minor output difference from reference)
+    fused_rms_norm: bool = True  # True = use fused F.rms_norm (~64% faster, minor output difference from reference)
     vision_config = GEMMA4_VISION_CONFIG
     audio_config = GEMMA4_AUDIO_CONFIG
     mm_tokens_per_image = 280
 
-Gemma4_E4B_Config = Gemma4Config
-
 @dataclass
 class Gemma4_E2B_Config(Gemma4Config):
     hidden_size: int = 1536
@@ -104,7 +102,7 @@ class Gemma4Attention(nn.Module):
         self.q_norm = None
         self.k_norm = None
 
-        fused = getattr(config, 'fused_rms_norm', False)
+        fused = config.fused_rms_norm
         if config.q_norm == "gemma3":
             self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
         if config.k_norm == "gemma3":
@@ -188,18 +186,18 @@ class TransformerBlockGemma4(nn.Module):
 
         self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
 
-        num_kv_shared = getattr(config, 'num_kv_shared_layers', 0)
+        num_kv_shared = config.num_kv_shared_layers
         first_kv_shared = config.num_hidden_layers - num_kv_shared
-        mlp_size = config.intermediate_size * 2 if getattr(config, 'use_double_wide_mlp', False) and index >= first_kv_shared else None
+        mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size)
 
-        fused = getattr(config, 'fused_rms_norm', False)
+        fused = config.fused_rms_norm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
 
-        self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0)
+        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
         if self.hidden_size_per_layer_input:
             self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
             self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
@@ -255,7 +253,7 @@ class Gemma4Transformer(nn.Module):
     def __init__(self, config, device=None, dtype=None, ops=None):
         super().__init__()
         self.config = config
-        fused = getattr(config, 'fused_rms_norm', False)
+        fused = config.fused_rms_norm
 
         self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
         self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False)
@@ -280,7 +278,7 @@ class Gemma4Transformer(nn.Module):
         self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False)
 
         # Per-layer input mechanism
-        self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0)
+        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
         if self.hidden_size_per_layer_input:
             self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype)
             self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False)
@@ -354,7 +352,7 @@ class Gemma4Transformer(nn.Module):
             per_layer_inputs = per_layer_proj
 
         # KV sharing: later layers reuse KV from the last non-shared sliding/global layer
-        num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0)
+        num_kv_shared = self.config.num_kv_shared_layers
         first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers
         shared_sliding_kv = None  # KV from last non-shared sliding layer
         shared_global_kv = None  # KV from last non-shared global layer
@@ -450,7 +448,7 @@ class Gemma4AudioMixin:
 
 
 class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
-        self._init_model(Gemma4_E4B_Config(**config_dict), dtype, device, operations)
+        self._init_model(Gemma4Config(**config_dict), dtype, device, operations)
         self._init_audio(self.model.config, dtype, device, operations)
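
Note (not part of the patch): a minimal sketch of the tradeoff the fused_rms_norm
flag controls. The rmsnorm.py hunk truncates rms_norm() right after the weight
check, so the weight multiply and downcast below are assumptions, as is the
PyTorch >= 2.4 requirement for torch.nn.functional.rms_norm; the point is only
that the manual fp32 path and the native kernel round differently, which is the
"minor output difference" the config comment refers to.

import torch
import torch.nn.functional as F

def rms_norm_compat(x, weight=None, eps=1e-6):
    # Manual fp32 path, mirroring the non-fused branch shown in comfy/rmsnorm.py.
    # The tail (weight multiply + cast back to the input dtype) is assumed,
    # not copied from the hunk.
    orig_dtype = x.dtype
    normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
    if weight is not None:
        normed = normed * weight.float()
    return normed.to(orig_dtype)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.bfloat16)
a = rms_norm_compat(x, w)
b = F.rms_norm(x, (x.shape[-1],), weight=w, eps=1e-6)  # fused/native path
print((a - b).abs().max())  # typically small but nonzero in low precision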