mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-19 05:27:24 +08:00
Default to fused rms_norm
This commit is contained in:
parent
cbaa07bf05
commit
ee728a795f
@ -3,15 +3,8 @@ import comfy.model_management
|
||||
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6, fused=True):
|
||||
if not fused: # compatibility mode as torch native rms_norm results are slightly different
|
||||
orig_dtype = x.dtype
|
||||
normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
|
||||
if weight is not None:
|
||||
weight = comfy.model_management.cast_to(weight, dtype=torch.float32, device=x.device)
|
||||
normed = normed * weight
|
||||
return normed.to(orig_dtype)
|
||||
|
||||
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if weight is None:
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
|
||||
@ -11,6 +11,12 @@ from comfy.rmsnorm import rms_norm
|
||||
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding
|
||||
|
||||
|
||||
# Intentional minor divergences from transformers -reference implementation:
|
||||
# Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor.
|
||||
# RMSNorm uses torch fused F.rms_norm
|
||||
# Input image and audio resizing/resampling slightly different numerically
|
||||
|
||||
|
||||
GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
|
||||
@ -45,7 +51,6 @@ class Gemma4Config:
|
||||
num_kv_shared_layers: int = 18
|
||||
use_double_wide_mlp: bool = False
|
||||
stop_tokens = [1, 50, 106]
|
||||
fused_rms_norm: bool = True # False: to match reference code's exact numerical behavior, which is much slower, so we default to True
|
||||
vision_config = GEMMA4_VISION_CONFIG
|
||||
audio_config = GEMMA4_AUDIO_CONFIG
|
||||
mm_tokens_per_image = 280
|
||||
@ -99,11 +104,10 @@ class Gemma4Attention(nn.Module):
|
||||
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
fused = config.fused_rms_norm
|
||||
if config.q_norm == "gemma3":
|
||||
self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
|
||||
self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
if config.k_norm == "gemma3":
|
||||
self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
|
||||
self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -132,7 +136,7 @@ class Gemma4Attention(nn.Module):
|
||||
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
||||
if self.k_norm is not None:
|
||||
xk = self.k_norm(xk)
|
||||
xv = rms_norm(xv, fused=False)
|
||||
xv = rms_norm(xv)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
xq = _apply_rotary_pos_emb(xq, freqs_cis)
|
||||
@ -189,17 +193,16 @@ class TransformerBlockGemma4(nn.Module):
|
||||
mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size)
|
||||
|
||||
fused = config.fused_rms_norm
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
|
||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
|
||||
if self.hidden_size_per_layer_input:
|
||||
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
|
||||
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused)
|
||||
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
|
||||
else:
|
||||
self.layer_scalar = None
|
||||
@ -251,7 +254,6 @@ class Gemma4Transformer(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
fused = config.fused_rms_norm
|
||||
|
||||
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
|
||||
|
||||
@ -260,7 +262,7 @@ class Gemma4Transformer(nn.Module):
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.final_norm else None
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None
|
||||
|
||||
# Precompute RoPE inv_freq on CPU to match reference code's exact value
|
||||
rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2)
|
||||
@ -282,7 +284,7 @@ class Gemma4Transformer(nn.Module):
|
||||
bias=False, device=device, dtype=dtype)
|
||||
self.per_layer_projection_norm = RMSNorm(
|
||||
self.hidden_size_per_layer_input, eps=config.rms_norm_eps,
|
||||
device=device, dtype=dtype, fused=fused)
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def get_past_len(self, past_key_values):
|
||||
for kv in past_key_values:
|
||||
@ -533,8 +535,8 @@ class Gemma4VisionAttention(nn.Module):
|
||||
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
|
||||
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, freqs, attention_mask=None):
|
||||
batch_size, seq_length, _ = x.shape
|
||||
@ -545,7 +547,7 @@ class Gemma4VisionAttention(nn.Module):
|
||||
|
||||
xq = self.q_norm(xq).transpose(1, 2)
|
||||
xk = self.k_norm(xk).transpose(1, 2)
|
||||
xv = rms_norm(xv, fused=False)
|
||||
xv = rms_norm(xv)
|
||||
|
||||
xq = _apply_vision_2d_rope(xq, freqs)
|
||||
xk = _apply_vision_2d_rope(xk, freqs)
|
||||
@ -561,7 +563,7 @@ class Gemma4VisionLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
|
||||
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
hidden = config["hidden_size"]
|
||||
self.input_layernorm = RMSNorm(hidden, **norm_kwargs)
|
||||
self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs)
|
||||
@ -703,7 +705,7 @@ class Gemma4RMSNormProjector(nn.Module):
|
||||
self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embedding_projection(rms_norm(x, fused=False))
|
||||
return self.embedding_projection(rms_norm(x))
|
||||
|
||||
|
||||
class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
|
||||
@ -753,10 +755,10 @@ class Gemma4AudioFeedForward(nn.Module):
|
||||
super().__init__()
|
||||
hidden_size = config["hidden_size"]
|
||||
intermediate_size = config.get("intermediate_size", hidden_size * 4)
|
||||
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
self.post_layer_scale = config.get("residual_weight", 0.5)
|
||||
|
||||
def forward(self, x):
|
||||
@ -897,12 +899,12 @@ class Gemma4AudioLConv1d(nn.Module):
|
||||
super().__init__()
|
||||
hidden_size = config["hidden_size"]
|
||||
conv_kernel_size = config.get("conv_kernel_size", 5)
|
||||
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops)
|
||||
# Causal conv: left-pad only
|
||||
self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1
|
||||
self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, x):
|
||||
@ -925,7 +927,7 @@ class Gemma4AudioLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
|
||||
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
|
||||
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
hidden_size = config["hidden_size"]
|
||||
self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs)
|
||||
self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs)
|
||||
@ -1007,9 +1009,7 @@ class Gemma4_Tokenizer():
|
||||
waveform = waveform.unsqueeze(0)
|
||||
audio = waveform.squeeze(0).float().numpy()
|
||||
if sample_rate != 16000:
|
||||
# import librosa
|
||||
# audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
|
||||
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (still not full match)
|
||||
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match)
|
||||
from scipy.signal import resample_poly, firwin
|
||||
from math import gcd
|
||||
g = gcd(sample_rate, 16000)
|
||||
|
||||
@ -382,19 +382,18 @@ class Gemma3_12B_Config:
|
||||
stop_tokens = [1, 106]
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None, fused=True):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
self.add = add
|
||||
self.fused = fused
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
w = self.weight
|
||||
if self.add:
|
||||
w = w + 1.0
|
||||
|
||||
return comfy.ldm.common_dit.rms_norm(x, w, self.eps, fused=self.fused)
|
||||
return comfy.ldm.common_dit.rms_norm(x, w, self.eps)
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user