Default to fused rms_norm

This commit is contained in:
kijai 2026-04-30 14:30:52 +03:00
parent cbaa07bf05
commit ee728a795f
3 changed files with 31 additions and 39 deletions

View File

@ -3,15 +3,8 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6, fused=True):
if not fused: # compatibility mode as torch native rms_norm results are slightly different
orig_dtype = x.dtype
normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
if weight is not None:
weight = comfy.model_management.cast_to(weight, dtype=torch.float32, device=x.device)
normed = normed * weight
return normed.to(orig_dtype)
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else:

View File

@ -11,6 +11,12 @@ from comfy.rmsnorm import rms_norm
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding
# Intentional minor divergences from transformers -reference implementation:
# Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor.
# RMSNorm uses torch fused F.rms_norm
# Input image and audio resizing/resampling slightly different numerically
GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
@ -45,7 +51,6 @@ class Gemma4Config:
num_kv_shared_layers: int = 18
use_double_wide_mlp: bool = False
stop_tokens = [1, 50, 106]
fused_rms_norm: bool = True # False: to match reference code's exact numerical behavior, which is much slower, so we default to True
vision_config = GEMMA4_VISION_CONFIG
audio_config = GEMMA4_AUDIO_CONFIG
mm_tokens_per_image = 280
@ -99,11 +104,10 @@ class Gemma4Attention(nn.Module):
self.q_norm = None
self.k_norm = None
fused = config.fused_rms_norm
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(
self,
@ -132,7 +136,7 @@ class Gemma4Attention(nn.Module):
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
if self.k_norm is not None:
xk = self.k_norm(xk)
xv = rms_norm(xv, fused=False)
xv = rms_norm(xv)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
xq = _apply_rotary_pos_emb(xq, freqs_cis)
@ -189,17 +193,16 @@ class TransformerBlockGemma4(nn.Module):
mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size)
fused = config.fused_rms_norm
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
if self.hidden_size_per_layer_input:
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
else:
self.layer_scalar = None
@ -251,7 +254,6 @@ class Gemma4Transformer(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
fused = config.fused_rms_norm
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
@ -260,7 +262,7 @@ class Gemma4Transformer(nn.Module):
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.final_norm else None
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None
# Precompute RoPE inv_freq on CPU to match reference code's exact value
rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2)
@ -282,7 +284,7 @@ class Gemma4Transformer(nn.Module):
bias=False, device=device, dtype=dtype)
self.per_layer_projection_norm = RMSNorm(
self.hidden_size_per_layer_input, eps=config.rms_norm_eps,
device=device, dtype=dtype, fused=fused)
device=device, dtype=dtype)
def get_past_len(self, past_key_values):
for kv in past_key_values:
@ -533,8 +535,8 @@ class Gemma4VisionAttention(nn.Module):
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
def forward(self, x, freqs, attention_mask=None):
batch_size, seq_length, _ = x.shape
@ -545,7 +547,7 @@ class Gemma4VisionAttention(nn.Module):
xq = self.q_norm(xq).transpose(1, 2)
xk = self.k_norm(xk).transpose(1, 2)
xv = rms_norm(xv, fused=False)
xv = rms_norm(xv)
xq = _apply_vision_2d_rope(xq, freqs)
xk = _apply_vision_2d_rope(xk, freqs)
@ -561,7 +563,7 @@ class Gemma4VisionLayer(nn.Module):
super().__init__()
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
hidden = config["hidden_size"]
self.input_layernorm = RMSNorm(hidden, **norm_kwargs)
self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs)
@ -703,7 +705,7 @@ class Gemma4RMSNormProjector(nn.Module):
self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.embedding_projection(rms_norm(x, fused=False))
return self.embedding_projection(rms_norm(x))
class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
@ -753,10 +755,10 @@ class Gemma4AudioFeedForward(nn.Module):
super().__init__()
hidden_size = config["hidden_size"]
intermediate_size = config.get("intermediate_size", hidden_size * 4)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.post_layer_scale = config.get("residual_weight", 0.5)
def forward(self, x):
@ -897,12 +899,12 @@ class Gemma4AudioLConv1d(nn.Module):
super().__init__()
hidden_size = config["hidden_size"]
conv_kernel_size = config.get("conv_kernel_size", 5)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops)
# Causal conv: left-pad only
self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1
self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops)
def forward(self, x):
@ -925,7 +927,7 @@ class Gemma4AudioLayer(nn.Module):
super().__init__()
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
hidden_size = config["hidden_size"]
self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs)
self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs)
@ -1007,9 +1009,7 @@ class Gemma4_Tokenizer():
waveform = waveform.unsqueeze(0)
audio = waveform.squeeze(0).float().numpy()
if sample_rate != 16000:
# import librosa
# audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (still not full match)
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match)
from scipy.signal import resample_poly, firwin
from math import gcd
g = gcd(sample_rate, 16000)

View File

@ -382,19 +382,18 @@ class Gemma3_12B_Config:
stop_tokens = [1, 106]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None, fused=True):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
self.add = add
self.fused = fused
def forward(self, x: torch.Tensor):
w = self.weight
if self.add:
w = w + 1.0
return comfy.ldm.common_dit.rms_norm(x, w, self.eps, fused=self.fused)
return comfy.ldm.common_dit.rms_norm(x, w, self.eps)