Default to fused rms_norm

This commit is contained in:
kijai 2026-04-30 14:30:52 +03:00
parent cbaa07bf05
commit ee728a795f
3 changed files with 31 additions and 39 deletions

View File

@ -3,15 +3,8 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6, fused=True): # Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
if not fused: # compatibility mode as torch native rms_norm results are slightly different def rms_norm(x, weight=None, eps=1e-6):
orig_dtype = x.dtype
normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
if weight is not None:
weight = comfy.model_management.cast_to(weight, dtype=torch.float32, device=x.device)
normed = normed * weight
return normed.to(orig_dtype)
if weight is None: if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else: else:

View File

@ -11,6 +11,12 @@ from comfy.rmsnorm import rms_norm
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding
# Intentional minor divergences from transformers -reference implementation:
# Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor.
# RMSNorm uses torch fused F.rms_norm
# Input image and audio resizing/resampling slightly different numerically
GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
@ -45,7 +51,6 @@ class Gemma4Config:
num_kv_shared_layers: int = 18 num_kv_shared_layers: int = 18
use_double_wide_mlp: bool = False use_double_wide_mlp: bool = False
stop_tokens = [1, 50, 106] stop_tokens = [1, 50, 106]
fused_rms_norm: bool = True # False: to match reference code's exact numerical behavior, which is much slower, so we default to True
vision_config = GEMMA4_VISION_CONFIG vision_config = GEMMA4_VISION_CONFIG
audio_config = GEMMA4_AUDIO_CONFIG audio_config = GEMMA4_AUDIO_CONFIG
mm_tokens_per_image = 280 mm_tokens_per_image = 280
@ -99,11 +104,10 @@ class Gemma4Attention(nn.Module):
self.q_norm = None self.q_norm = None
self.k_norm = None self.k_norm = None
fused = config.fused_rms_norm
if config.q_norm == "gemma3": if config.q_norm == "gemma3":
self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
if config.k_norm == "gemma3": if config.k_norm == "gemma3":
self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward( def forward(
self, self,
@ -132,7 +136,7 @@ class Gemma4Attention(nn.Module):
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
if self.k_norm is not None: if self.k_norm is not None:
xk = self.k_norm(xk) xk = self.k_norm(xk)
xv = rms_norm(xv, fused=False) xv = rms_norm(xv)
xk = xk.transpose(1, 2) xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2) xv = xv.transpose(1, 2)
xq = _apply_rotary_pos_emb(xq, freqs_cis) xq = _apply_rotary_pos_emb(xq, freqs_cis)
@ -189,17 +193,16 @@ class TransformerBlockGemma4(nn.Module):
mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size)
fused = config.fused_rms_norm self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
if self.hidden_size_per_layer_input: if self.hidden_size_per_layer_input:
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
else: else:
self.layer_scalar = None self.layer_scalar = None
@ -251,7 +254,6 @@ class Gemma4Transformer(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None): def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__() super().__init__()
self.config = config self.config = config
fused = config.fused_rms_norm
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
@ -260,7 +262,7 @@ class Gemma4Transformer(nn.Module):
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.final_norm else None self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None
# Precompute RoPE inv_freq on CPU to match reference code's exact value # Precompute RoPE inv_freq on CPU to match reference code's exact value
rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2) rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2)
@ -282,7 +284,7 @@ class Gemma4Transformer(nn.Module):
bias=False, device=device, dtype=dtype) bias=False, device=device, dtype=dtype)
self.per_layer_projection_norm = RMSNorm( self.per_layer_projection_norm = RMSNorm(
self.hidden_size_per_layer_input, eps=config.rms_norm_eps, self.hidden_size_per_layer_input, eps=config.rms_norm_eps,
device=device, dtype=dtype, fused=fused) device=device, dtype=dtype)
def get_past_len(self, past_key_values): def get_past_len(self, past_key_values):
for kv in past_key_values: for kv in past_key_values:
@ -533,8 +535,8 @@ class Gemma4VisionAttention(nn.Module):
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
def forward(self, x, freqs, attention_mask=None): def forward(self, x, freqs, attention_mask=None):
batch_size, seq_length, _ = x.shape batch_size, seq_length, _ = x.shape
@ -545,7 +547,7 @@ class Gemma4VisionAttention(nn.Module):
xq = self.q_norm(xq).transpose(1, 2) xq = self.q_norm(xq).transpose(1, 2)
xk = self.k_norm(xk).transpose(1, 2) xk = self.k_norm(xk).transpose(1, 2)
xv = rms_norm(xv, fused=False) xv = rms_norm(xv)
xq = _apply_vision_2d_rope(xq, freqs) xq = _apply_vision_2d_rope(xq, freqs)
xk = _apply_vision_2d_rope(xk, freqs) xk = _apply_vision_2d_rope(xk, freqs)
@ -561,7 +563,7 @@ class Gemma4VisionLayer(nn.Module):
super().__init__() super().__init__()
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
hidden = config["hidden_size"] hidden = config["hidden_size"]
self.input_layernorm = RMSNorm(hidden, **norm_kwargs) self.input_layernorm = RMSNorm(hidden, **norm_kwargs)
self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs) self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs)
@ -703,7 +705,7 @@ class Gemma4RMSNormProjector(nn.Module):
self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
def forward(self, x): def forward(self, x):
return self.embedding_projection(rms_norm(x, fused=False)) return self.embedding_projection(rms_norm(x))
class Gemma4MultiModalProjector(Gemma4RMSNormProjector): class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
@ -753,10 +755,10 @@ class Gemma4AudioFeedForward(nn.Module):
super().__init__() super().__init__()
hidden_size = config["hidden_size"] hidden_size = config["hidden_size"]
intermediate_size = config.get("intermediate_size", hidden_size * 4) intermediate_size = config.get("intermediate_size", hidden_size * 4)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.post_layer_scale = config.get("residual_weight", 0.5) self.post_layer_scale = config.get("residual_weight", 0.5)
def forward(self, x): def forward(self, x):
@ -897,12 +899,12 @@ class Gemma4AudioLConv1d(nn.Module):
super().__init__() super().__init__()
hidden_size = config["hidden_size"] hidden_size = config["hidden_size"]
conv_kernel_size = config.get("conv_kernel_size", 5) conv_kernel_size = config.get("conv_kernel_size", 5)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops)
# Causal conv: left-pad only # Causal conv: left-pad only
self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1
self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops)
def forward(self, x): def forward(self, x):
@ -925,7 +927,7 @@ class Gemma4AudioLayer(nn.Module):
super().__init__() super().__init__()
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
hidden_size = config["hidden_size"] hidden_size = config["hidden_size"]
self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs)
self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs)
@ -1007,9 +1009,7 @@ class Gemma4_Tokenizer():
waveform = waveform.unsqueeze(0) waveform = waveform.unsqueeze(0)
audio = waveform.squeeze(0).float().numpy() audio = waveform.squeeze(0).float().numpy()
if sample_rate != 16000: if sample_rate != 16000:
# import librosa # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match)
# audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (still not full match)
from scipy.signal import resample_poly, firwin from scipy.signal import resample_poly, firwin
from math import gcd from math import gcd
g = gcd(sample_rate, 16000) g = gcd(sample_rate, 16000)

View File

@ -382,19 +382,18 @@ class Gemma3_12B_Config:
stop_tokens = [1, 106] stop_tokens = [1, 106]
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None, fused=True): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
self.add = add self.add = add
self.fused = fused
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
w = self.weight w = self.weight
if self.add: if self.add:
w = w + 1.0 w = w + 1.0
return comfy.ldm.common_dit.rms_norm(x, w, self.eps, fused=self.fused) return comfy.ldm.common_dit.rms_norm(x, w, self.eps)