diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index af0978341..e54be98d6 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -3,15 +3,8 @@ import comfy.model_management RMSNorm = torch.nn.RMSNorm -def rms_norm(x, weight=None, eps=1e-6, fused=True): - if not fused: # compatibility mode as torch native rms_norm results are slightly different - orig_dtype = x.dtype - normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5) - if weight is not None: - weight = comfy.model_management.cast_to(weight, dtype=torch.float32, device=x.device) - normed = normed * weight - return normed.to(orig_dtype) - +# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding). +def rms_norm(x, weight=None, eps=1e-6): if weight is None: return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) else: diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 1b70eadb5..61ff42501 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -11,6 +11,12 @@ from comfy.rmsnorm import rms_norm from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding +# Intentional minor divergences from transformers -reference implementation: +# Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor. +# RMSNorm uses torch fused F.rms_norm +# Input image and audio resizing/resampling slightly different numerically + + GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} @@ -45,7 +51,6 @@ class Gemma4Config: num_kv_shared_layers: int = 18 use_double_wide_mlp: bool = False stop_tokens = [1, 50, 106] - fused_rms_norm: bool = True # False: to match reference code's exact numerical behavior, which is much slower, so we default to True vision_config = GEMMA4_VISION_CONFIG audio_config = GEMMA4_AUDIO_CONFIG mm_tokens_per_image = 280 @@ -99,11 +104,10 @@ class Gemma4Attention(nn.Module): self.q_norm = None self.k_norm = None - fused = config.fused_rms_norm if config.q_norm == "gemma3": - self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.k_norm == "gemma3": - self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) def forward( self, @@ -132,7 +136,7 @@ class Gemma4Attention(nn.Module): xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) if self.k_norm is not None: xk = self.k_norm(xk) - xv = rms_norm(xv, fused=False) + xv = rms_norm(xv) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) xq = _apply_rotary_pos_emb(xq, freqs_cis) @@ -189,17 +193,16 @@ class TransformerBlockGemma4(nn.Module): mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) - fused = config.fused_rms_norm - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) - self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.hidden_size_per_layer_input = config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) - self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) else: self.layer_scalar = None @@ -251,7 +254,6 @@ class Gemma4Transformer(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.config = config - fused = config.fused_rms_norm self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) @@ -260,7 +262,7 @@ class Gemma4Transformer(nn.Module): for i in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.final_norm else None + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None # Precompute RoPE inv_freq on CPU to match reference code's exact value rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2) @@ -282,7 +284,7 @@ class Gemma4Transformer(nn.Module): bias=False, device=device, dtype=dtype) self.per_layer_projection_norm = RMSNorm( self.hidden_size_per_layer_input, eps=config.rms_norm_eps, - device=device, dtype=dtype, fused=fused) + device=device, dtype=dtype) def get_past_len(self, past_key_values): for kv in past_key_values: @@ -533,8 +535,8 @@ class Gemma4VisionAttention(nn.Module): self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) - self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) - self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) def forward(self, x, freqs, attention_mask=None): batch_size, seq_length, _ = x.shape @@ -545,7 +547,7 @@ class Gemma4VisionAttention(nn.Module): xq = self.q_norm(xq).transpose(1, 2) xk = self.k_norm(xk).transpose(1, 2) - xv = rms_norm(xv, fused=False) + xv = rms_norm(xv) xq = _apply_vision_2d_rope(xq, freqs) xk = _apply_vision_2d_rope(xk, freqs) @@ -561,7 +563,7 @@ class Gemma4VisionLayer(nn.Module): super().__init__() self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) - norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) hidden = config["hidden_size"] self.input_layernorm = RMSNorm(hidden, **norm_kwargs) self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs) @@ -703,7 +705,7 @@ class Gemma4RMSNormProjector(nn.Module): self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) def forward(self, x): - return self.embedding_projection(rms_norm(x, fused=False)) + return self.embedding_projection(rms_norm(x)) class Gemma4MultiModalProjector(Gemma4RMSNormProjector): @@ -753,10 +755,10 @@ class Gemma4AudioFeedForward(nn.Module): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config.get("intermediate_size", hidden_size * 4) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) - self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.post_layer_scale = config.get("residual_weight", 0.5) def forward(self, x): @@ -897,12 +899,12 @@ class Gemma4AudioLConv1d(nn.Module): super().__init__() hidden_size = config["hidden_size"] conv_kernel_size = config.get("conv_kernel_size", 5) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) # Causal conv: left-pad only self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 - self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) def forward(self, x): @@ -925,7 +927,7 @@ class Gemma4AudioLayer(nn.Module): super().__init__() self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) - norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) hidden_size = config["hidden_size"] self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) @@ -1007,9 +1009,7 @@ class Gemma4_Tokenizer(): waveform = waveform.unsqueeze(0) audio = waveform.squeeze(0).float().numpy() if sample_rate != 16000: - # import librosa - # audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000) - # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (still not full match) + # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match) from scipy.signal import resample_poly, firwin from math import gcd g = gcd(sample_rate, 16000) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d1c43adb2..a34c41144 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -382,19 +382,18 @@ class Gemma3_12B_Config: stop_tokens = [1, 106] class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None, fused=True): + def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) self.add = add - self.fused = fused def forward(self, x: torch.Tensor): w = self.weight if self.add: w = w + 1.0 - return comfy.ldm.common_dit.rms_norm(x, w, self.eps, fused=self.fused) + return comfy.ldm.common_dit.rms_norm(x, w, self.eps)