From 832753f4970944f44d45d4b67646136dd19915e5 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 3 Apr 2026 03:46:45 +0300
Subject: [PATCH 01/18] initial gemma4 support

---
 comfy/sd.py                   |    8 +
 comfy/text_encoders/gemma4.py | 1196 +++++++++++++++++++++++++++++++++
 comfy/text_encoders/llama.py  |    8 +-
 comfy_extras/nodes_textgen.py |    5 +-
 4 files changed, 1212 insertions(+), 5 deletions(-)
 create mode 100644 comfy/text_encoders/gemma4.py

diff --git a/comfy/sd.py b/comfy/sd.py
index 5b6b59ea4..9b1960286 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -62,6 +62,7 @@ import comfy.text_encoders.anima
 import comfy.text_encoders.ace15
 import comfy.text_encoders.longcat_image
 import comfy.text_encoders.qwen35
+import comfy.text_encoders.gemma4
 
 import comfy.model_patcher
 import comfy.lora
@@ -1228,6 +1229,7 @@ class TEModel(Enum):
     QWEN35_4B = 25
     QWEN35_9B = 26
     QWEN35_27B = 27
+    GEMMA_4_E4B = 28
 
 
 def detect_te_model(sd):
@@ -1253,6 +1255,8 @@ def detect_te_model(sd):
             return TEModel.BYT5_SMALL_GLYPH
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+        if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
+            return TEModel.GEMMA_4_E4B
         if 'model.layers.47.self_attn.q_norm.weight' in sd:
             return TEModel.GEMMA_3_12B
         if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1390,6 +1394,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             else:
                 clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
                 clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+        elif te_model == TEModel.GEMMA_4_E4B:
+            clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4Tokenizer
+            tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
         elif te_model == TEModel.GEMMA_2_2B:
             clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py
new file mode 100644
index 000000000..c3a964cc4
--- /dev/null
+++ b/comfy/text_encoders/gemma4.py
@@ -0,0 +1,1196 @@
+import torch
+import torch.nn as nn
+from dataclasses import dataclass
+
+from comfy import sd1_clip
+import comfy.utils
+import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.text_encoders.llama import RMSNorm, BaseLlama, BaseGenerate, Llama2_
+
+
+GEMMA4_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "model_type": "gemma4_vision", "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
+GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5, "gradient_clipping": 1e10, "hidden_act": "silu"}
+
+@dataclass
+class Gemma4_E4B_Config:  # NOTE: only the annotated names are dataclass fields (accepted by __init__); un-annotated ones are plain class attributes
+    vocab_size: int = 262144
+    hidden_size: int = 2560
+    intermediate_size: int = 10240
+    num_hidden_layers: int = 42
+    num_attention_heads: int = 8
+    num_key_value_heads: int = 2
+    max_position_embeddings: int = 131072
+    rms_norm_eps: float = 1e-6
+    rope_theta = [1000000.0, 10000.0]  # [global-layer theta, sliding-layer theta], indexed in Gemma4Transformer.compute_freqs_cis
+    transformer_type: str = "gemma4"
+    head_dim = 256  # head size used by sliding-attention layers
+    global_head_dim = 512  # head size used by global-attention layers
+    rms_norm_add = False
+    mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    sliding_attention = [512, 512, 512, 512, 512, False]  # per-layer pattern, cycled by layer index; False marks a global layer
+    rope_scale = None
+    partial_rotary_factor: float = 0.25
+    final_norm: bool = True
+    lm_head: bool = False
+    final_logit_softcapping: float = 30.0
+    hidden_size_per_layer_input: int = 256
+    num_kv_shared_layers: int = 18  # the last N layers reuse K/V from earlier layers
+    stop_tokens = [1, 106]
+    vision_config = GEMMA4_VISION_CONFIG
+    audio_config = GEMMA4_AUDIO_CONFIG
+    mm_tokens_per_image = 280
+
+
+def precompute_freqs_cis_proportional(head_dim, partial_rotary_factor, position_ids, theta, device=None):
+    """Proportional RoPE: compute freqs for full head_dim, but only first rope_angles get non-zero frequencies."""
+    rope_angles = int(partial_rotary_factor * head_dim // 2)
+    nope_angles = head_dim // 2 - rope_angles
+
+    theta_numerator = torch.arange(0, 2 * rope_angles, 2, device=device).float()
+    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
+
+    if nope_angles > 0:
+        inv_freq = torch.cat([inv_freq, torch.zeros(nope_angles, device=device)], dim=0)  # zero frequency => cos=1, sin=0 (no rotation)
+
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
+    emb = torch.cat((freqs, freqs), dim=-1)
+    cos = emb.cos().unsqueeze(1)
+    sin = emb.sin().unsqueeze(1)
+    sin_split = sin.shape[-1] // 2
+    return (cos, sin[..., :sin_split], -sin[..., sin_split:])  # (cos, first half of sin, negated second half of sin)
+
+
+class Gemma4Attention(nn.Module):  # self-attention that can either project its own K/V or borrow them via `shared_kv`
+    def __init__(self, config, head_dim, device=None, dtype=None, ops=None):
+        super().__init__()
+        from comfy.text_encoders.llama import RMSNorm
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.hidden_size = config.hidden_size
+        self.head_dim = head_dim
+        self.inner_size = self.num_heads * head_dim
+
+        ops = ops or nn
+        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
+        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+
+        self.q_norm = None
+        self.k_norm = None
+        if config.q_norm == "gemma3":
+            self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        if config.k_norm == "gemma3":
+            self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+    def forward(  # returns (attention output, present_key_value or None, shareable_kv or None)
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask=None,
+        freqs_cis=None,
+        optimized_attention=None,
+        past_key_value=None,
+        sliding_window=None,
+        shared_kv=None,
+    ):
+        from comfy.text_encoders.llama import apply_rope
+        batch_size, seq_length, _ = hidden_states.shape
+
+        xq = self.q_proj(hidden_states)
+        xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+        if self.q_norm is not None:
+            xq = self.q_norm(xq)
+
+        if shared_kv is not None:
+            # KV-shared layer: borrow KV from source layer, skip own cache
+            if len(shared_kv) == 3:  # (key_cache, value_cache, valid_length) tuple from a cached source layer
+                xk, xv = shared_kv[0][:, :, :shared_kv[2]], shared_kv[1][:, :, :shared_kv[2]]
+            else:
+                xk, xv = shared_kv
+            # Apply RoPE to Q only (K already has RoPE from source layer)
+            xq, _ = apply_rope(xq, xq, freqs_cis=freqs_cis)  # dummy K, only Q result used
+            present_key_value = None
+            shareable_kv = None  # NOTE(review): borrowed K/V is never truncated to sliding_window here — confirm that is intended
+        else:
+            xk = self.k_proj(hidden_states)
+            xv = self.v_proj(hidden_states)
+            xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
+            xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
+            if self.k_norm is not None:
+                xk = self.k_norm(xk)
+            xv = _parameterless_rms_norm(xv)
+            xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
+
+            present_key_value = None
+            if past_key_value is not None:
+                index = 0
+                num_tokens = xk.shape[2]
+                if len(past_key_value) > 0:
+                    past_key, past_value, index = past_key_value
+                    if past_key.shape[2] >= (index + num_tokens):  # pre-allocated cache has room: write in place
+                        past_key[:, :, index:index + xk.shape[2]] = xk
+                        past_value[:, :, index:index + xv.shape[2]] = xv
+                        xk = past_key[:, :, :index + xk.shape[2]]
+                        xv = past_value[:, :, :index + xv.shape[2]]
+                        present_key_value = (past_key, past_value, index + num_tokens)
+                    else:  # cache too small: grow by concatenation
+                        xk = torch.cat((past_key[:, :, :index], xk), dim=2)
+                        xv = torch.cat((past_value[:, :, :index], xv), dim=2)
+                        present_key_value = (xk, xv, index + num_tokens)
+                else:
+                    present_key_value = (xk, xv, index + num_tokens)
+
+            if sliding_window is not None and seq_length == 1 and xk.shape[2] > sliding_window:  # decode-only: in multi-token prefill the block-level sliding mask enforces the window, and dropping keys there would starve early queries
+                xk = xk[:, :, -sliding_window:]
+                xv = xv[:, :, -sliding_window:]
+                attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
+
+            # KV for sharing with later layers
+            shareable_kv = present_key_value if present_key_value is not None else (xk, xv)
+
+        xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)  # expand KV heads to match query heads (GQA)
+        xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+
+        # scaling=1.0: pre-multiply Q to cancel optimized_attention's 1/sqrt(head_dim)
+        xq = xq * (self.head_dim ** 0.5)
+
+        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
+        return self.o_proj(output), present_key_value, shareable_kv
+
+
+class TransformerBlockGemma4(nn.Module):  # one decoder layer: attention + MLP (each with pre/post norms) + optional per-layer-input gate
+    def __init__(self, config, index, device=None, dtype=None, ops=None):
+        super().__init__()
+        from comfy.text_encoders.llama import MLP
+        if config.sliding_attention is not None:
+            self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+        else:
+            self.sliding_attention = False
+
+        head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
+
+        self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
+        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+        self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0)
+        if self.hidden_size_per_layer_input:
+            ops_pl = ops or nn
+            self.per_layer_input_gate = ops_pl.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
+            self.per_layer_projection = ops_pl.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
+            self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+            self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
+        else:
+            self.layer_scalar = None
+
+    def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None,
+                past_key_value=None, per_layer_input=None, shared_kv=None):
+        sliding_window = None
+        if self.sliding_attention:
+            sliding_window = self.sliding_attention
+            if x.shape[1] > self.sliding_attention:
+                sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
+                sliding_mask.tril_(diagonal=-self.sliding_attention)  # keeps the large-negative value only where key <= query - window
+                attention_mask = attention_mask + sliding_mask if attention_mask is not None else sliding_mask
+            freqs_cis = freqs_cis[1]  # sliding-layer RoPE table
+        else:
+            freqs_cis = freqs_cis[0]  # global-layer (proportional) RoPE table
+
+        residual = x
+        x = self.input_layernorm(x)
+        x, present_key_value, shareable_kv = self.self_attn(
+            hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis,
+            optimized_attention=optimized_attention, past_key_value=past_key_value,
+            sliding_window=sliding_window, shared_kv=shared_kv,
+        )
+        x = self.post_attention_layernorm(x)
+        x = residual + x
+
+        residual = x
+        x = self.pre_feedforward_layernorm(x)
+        x = self.mlp(x)
+        x = self.post_feedforward_layernorm(x)
+        x = residual + x
+
+        if self.hidden_size_per_layer_input and per_layer_input is not None:
+            residual = x
+            x = self.per_layer_input_gate(x)
+            x = torch.nn.functional.gelu(x, approximate="tanh")
+            x = x * per_layer_input
+            x = self.per_layer_projection(x)
+            x = self.post_per_layer_input_norm(x)
+            x = residual + x
+
+        if self.layer_scalar is not None:
+            x = x * self.layer_scalar
+
+        return x, present_key_value, shareable_kv
+
+
+class Gemma4Transformer(Llama2_):
+    """Llama2_ subclass with Gemma4-specific features: per-layer inputs, KV sharing, proportional RoPE."""
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__(config, device=device, dtype=dtype, ops=ops)
+        # Override transformer type
+        self.normalize_in = True
+        # Replace layers with Gemma4 blocks
+        self.layers = nn.ModuleList([
+            TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops)
+            for i in range(config.num_hidden_layers)
+        ])
+        # Per-layer input mechanism
+        self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0)
+        if self.hidden_size_per_layer_input:
+            self.embed_tokens_per_layer = ops.Embedding(
+                config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input,
+                device=device, dtype=dtype)
+            self.per_layer_model_projection = ops.Linear(
+                config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input,
+                bias=False, device=device, dtype=dtype)
+            self.per_layer_projection_norm = RMSNorm(
+                self.hidden_size_per_layer_input, eps=config.rms_norm_eps,
+                add=config.rms_norm_add, device=device, dtype=dtype)
+
+    def compute_freqs_cis(self, position_ids, device):  # returns [global_freqs, sliding_freqs]; blocks index 0 for global, 1 for sliding
+        from comfy.text_encoders.llama import precompute_freqs_cis
+        global_freqs = precompute_freqs_cis_proportional(
+            self.config.global_head_dim, self.config.partial_rotary_factor,
+            position_ids, self.config.rope_theta[0], device=device)
+        sliding_freqs = precompute_freqs_cis(
+            self.config.head_dim, position_ids, self.config.rope_theta[1], device=device)
+        return [global_freqs, sliding_freqs]
+
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None,
+                final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[],
+                past_key_values=None, input_ids=None):
+        if embeds is not None:
+            x = embeds
+        else:
+            x = self.embed_tokens(x, out_dtype=dtype)
+
+        if self.normalize_in:
+            x = x * (self.config.hidden_size ** 0.5)  # out-of-place so a caller-supplied `embeds` tensor is not mutated
+
+        seq_len = x.shape[1]
+        past_len = 0
+        if past_key_values is not None and len(past_key_values) > 0:
+            past_len = self.get_past_len(past_key_values)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
+
+        freqs_cis = self.compute_freqs_cis(position_ids, x.device)
+
+        mask = None
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
+            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
+
+        if seq_len > 1:
+            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
+            mask = mask + causal_mask if mask is not None else causal_mask
+
+        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+
+        # Per-layer inputs
+        per_layer_inputs = None
+        if self.hidden_size_per_layer_input:
+            num_layers = self.config.num_hidden_layers
+            hpl = self.hidden_size_per_layer_input
+            per_layer_proj = self.per_layer_model_projection(x) * (1.0 / (self.config.hidden_size ** 0.5))
+            per_layer_proj = self.per_layer_projection_norm(per_layer_proj.reshape(*x.shape[:-1], num_layers, hpl))
+            if input_ids is not None and input_ids.shape[1] == x.shape[1]:
+                per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl) * (hpl ** 0.5)
+                per_layer_inputs = (per_layer_proj + per_layer_emb) * (0.5 ** 0.5)
+            else:  # no usable token ids: fall back to the projection alone
+                per_layer_inputs = per_layer_proj
+
+        # KV sharing: only last sliding (22) and last global (23) layers store KV for sharing
+        num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0)
+        first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers
+        shared_sliding_kv = None  # KV from last non-shared sliding layer
+        shared_global_kv = None  # KV from last non-shared global layer
+
+        intermediate = None
+        next_key_values = []
+        for i, layer in enumerate(self.layers):
+            past_kv = past_key_values[i] if past_key_values is not None and len(past_key_values) > 0 else None
+
+            layer_kwargs = {}
+            if per_layer_inputs is not None:
+                layer_kwargs['per_layer_input'] = per_layer_inputs[:, :, i, :]
+            if i >= first_kv_shared and num_kv_shared > 0:
+                is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention
+                shared = shared_sliding_kv if is_sliding else shared_global_kv
+                if shared is not None:
+                    layer_kwargs['shared_kv'] = shared
+
+            x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis,
+                                                optimized_attention=optimized_attention, past_key_value=past_kv, **layer_kwargs)
+
+            next_key_values.append(current_kv if current_kv is not None else ())  # () placeholder keeps the list aligned with layer indices
+
+            # Only track the last sliding/global before the sharing boundary
+            if i < first_kv_shared and shareable_kv is not None:
+                is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention
+                if is_sliding:
+                    shared_sliding_kv = shareable_kv
+                else:
+                    shared_global_kv = shareable_kv
+
+            if i == intermediate_output:
+                intermediate = x.clone()
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        if len(next_key_values) > 0:  # one entry is appended per layer, so this holds whenever the model has layers
+            return x, intermediate, next_key_values
+        return x, intermediate
+
+
+class Gemma4_E4B(BaseLlama, BaseGenerate, torch.nn.Module):  # text transformer plus vision/audio encoders and their projectors
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Gemma4_E4B_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
+        self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype, device, operations)
+        self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype, device, operations)
+        self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype, device, operations)
+        self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype, device, operations)
+
+    def logits(self, x):
+        logits = super().logits(x)
+        cap = self.model.config.final_logit_softcapping
+        if cap:
+            logits = cap * torch.tanh(logits / cap)  # soft-cap: squashes logits into (-cap, cap)
+        return logits
+
+    def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):  # one (key, value, 0) entry per non-shared layer, () for KV-shared layers
+        config = self.model.config
+        num_kv_shared = getattr(config, 'num_kv_shared_layers', 0)
+        first_kv_shared = config.num_hidden_layers - num_kv_shared
+        past_key_values = []
+        for i in range(config.num_hidden_layers):
+            if i >= first_kv_shared:
+                past_key_values.append(())  # shared layers don't need KV cache
+            else:
+                sa = config.sliding_attention[i % len(config.sliding_attention)]
+                hd = config.head_dim if sa else config.global_head_dim
+                past_key_values.append((
+                    torch.empty([batch, config.num_key_value_heads, max_cache_len, hd], device=device, dtype=execution_dtype),
+                    torch.empty([batch, config.num_key_value_heads, max_cache_len, hd], device=device, dtype=execution_dtype),
+                    0))
+        return past_key_values
+
+    def preprocess_embed(self, embed, device):  # returns (projected embeddings, None), or (None, None) for an unrecognised embed type
+        if embed["type"] == "image":
+            image = embed["data"].movedim(-1, 1)  # [B, H, W, C] -> [B, C, H, W]
+            vision_out = self.vision_model(image.to(device, dtype=torch.float32))
+            return self.multi_modal_projector(vision_out), None
+        if embed["type"] == "audio":
+            audio = embed["data"].to(device, dtype=torch.float32)
+            audio_out = self.audio_model(audio)
+            return self.audio_projector(audio_out), None
+        return None, None
+
+
+# --- Vision Encoder ---
+# Matches HF weight structure after conversion:
+# vision_model.patch_embedder.input_proj.weight [768, 768]
+# vision_model.patch_embedder.position_embedding_table [2, 10240, 768]
+# vision_model.encoder.layers.X.self_attn.{q,k,v,o}_proj.weight [768, 768]
+# vision_model.encoder.layers.X.self_attn.{q,k}_norm.weight [64]
+# vision_model.encoder.layers.X.mlp.{gate,up}_proj.weight [3072, 768]
+# vision_model.encoder.layers.X.mlp.down_proj.weight [768, 3072]
+# vision_model.encoder.layers.X.{input,post_attention,pre_feedforward,post_feedforward}_layernorm.weight [768]
+
+def _parameterless_rms_norm(x, eps=1e-6):
+    """RMSNorm without learnable weight (used by Gemma4 v_norm and projectors)."""
+    mean_squared = x.float().pow(2).mean(-1, keepdim=True) + eps
+    return (x.float() * torch.pow(mean_squared, -0.5)).to(x.dtype)  # computed in float32, cast back to the input dtype
+
+
+def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
+    """Compute 2D RoPE for vision: separate frequencies for x and y dimensions.
+
+    Args:
+        head_dim: dimension per head (e.g. 64)
+        pixel_position_ids: [batch, num_patches, 2] with (x, y) coords
+        theta: RoPE base frequency
+    Returns:
+        (cos, sin) each of shape [batch, num_patches, head_dim]
+    """
+    rotary_dim_per_axis = head_dim // 2
+    freq_indices = torch.arange(0, rotary_dim_per_axis, 2, device=device).float()
+    inv_freq = 1.0 / (theta ** (freq_indices / rotary_dim_per_axis))
+
+    all_cos, all_sin = [], []
+    for i in range(2):  # x and y
+        dim_positions = pixel_position_ids[:, :, i].float()  # [batch, num_patches]
+        freqs = torch.einsum('bi,j->bij', dim_positions, inv_freq.to(device))  # [batch, num_patches, rotary_dim/2]
+        emb = torch.cat([freqs, freqs], dim=-1)  # [batch, num_patches, rotary_dim]
+        all_cos.append(emb.cos())
+        all_sin.append(emb.sin())
+
+    cos = torch.cat(all_cos, dim=-1).to(pixel_position_ids.device)  # [batch, num_patches, head_dim]
+    sin = torch.cat(all_sin, dim=-1).to(pixel_position_ids.device)
+    return cos, sin
+
+
+def _apply_vision_2d_rope(x, cos, sin):
+    """Apply 2D RoPE (multidimensional) to vision query/key states.
+
+    Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently.
+
+    x: [batch, heads, seq, head_dim]
+    cos, sin: [batch, seq, head_dim]
+    """
+    cos = cos.unsqueeze(1)  # [batch, 1, seq, head_dim]
+    sin = sin.unsqueeze(1)
+
+    def rotate_half(t):
+        t1 = t[..., :t.shape[-1]//2]
+        t2 = t[..., t.shape[-1]//2:]
+        return torch.cat((-t2, t1), dim=-1)
+
+    # Split into 2 parts (y and x dimensions)
+    half = x.shape[-1] // 2
+    x_parts = [x[..., :half], x[..., half:]]
+    cos_parts = [cos[..., :half], cos[..., half:]]
+    sin_parts = [sin[..., :half], sin[..., half:]]
+
+    rotated_parts = []
+    for xp, cp, sp in zip(x_parts, cos_parts, sin_parts):
+        rotated_parts.append((xp * cp) + (rotate_half(xp) * sp))
+
+    return torch.cat(rotated_parts, dim=-1)
+
+
+class ClippedLinear(nn.Module):
+    """Linear layer with activation clipping (from quantization-aware training).
+
+    Stores input_max/min and output_max/min as buffers loaded from checkpoint.
+    """
+    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, operations=None):
+        super().__init__()
+        ops = operations or nn
+        self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
+        self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype))  # +/-inf defaults make the clamps a no-op until real bounds are loaded
+        self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype))
+        self.register_buffer('output_max', torch.tensor(float('inf'), device=device, dtype=dtype))
+        self.register_buffer('output_min', torch.tensor(float('-inf'), device=device, dtype=dtype))
+
+    @property
+    def weight(self):
+        return self.linear.weight
+
+    def forward(self, x):
+        x = x.clamp(min=self.input_min, max=self.input_max)
+        x = self.linear(x)
+        x = x.clamp(min=self.output_min, max=self.output_max)
+        return x
+
+
+def _make_clipped_linear(in_f, out_f, bias=False, device=None, dtype=None, operations=None):  # thin factory around ClippedLinear
+    return ClippedLinear(in_f, out_f, bias=bias, device=device, dtype=dtype, operations=operations)
+
+
+class Gemma4VisionMLP(nn.Module):
+    """Gated MLP (tanh-GELU gate) matching gate_proj/up_proj/down_proj structure."""
+    def __init__(self, config, device=None, dtype=None, operations=None):
+        super().__init__()
+        hidden_size = config["hidden_size"]
+        intermediate_size = config["intermediate_size"]
+        self.gate_proj = _make_clipped_linear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
+        self.up_proj = _make_clipped_linear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
+        self.down_proj = _make_clipped_linear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations)
+
+    def forward(self, x):
+        return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))
+
+
+class Gemma4VisionAttention(nn.Module):  # multi-head attention over patches with q/k RMSNorm, weightless v norm and optional 2D RoPE
+    def __init__(self, config, device=None, dtype=None, operations=None):
+        super().__init__()
+        self.hidden_size = config["hidden_size"]
+        self.num_heads = config["num_attention_heads"]
+        self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads)
+
+        self.q_proj = _make_clipped_linear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
+        self.k_proj = _make_clipped_linear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
+        self.v_proj = _make_clipped_linear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
+        self.o_proj = _make_clipped_linear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, operations=operations)
+
+        self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype)
+
+    def forward(self, x, cos_sin=None, attention_mask=None, optimized_attention=None):
+        batch_size, seq_length, _ = x.shape
+
+        xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
+        xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
+        xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
+
+        xq = self.q_norm(xq)
+        xk = self.k_norm(xk)
+        xv = _parameterless_rms_norm(xv)
+
+        # Apply 2D RoPE
+        if cos_sin is not None:
+            cos, sin = cos_sin
+            xq = xq.transpose(1, 2)  # [B, H, S, D]
+            xk = xk.transpose(1, 2)
+            xq = _apply_vision_2d_rope(xq, cos, sin)
+            xk = _apply_vision_2d_rope(xk, cos, sin)
+        else:
+            xq = xq.transpose(1, 2)
+            xk = xk.transpose(1, 2)
+
+        xv = xv.to(xq.dtype).transpose(1, 2)
+
+        # scaling=1.0 (Q/K already normalized), cancel optimized_attention's 1/sqrt(d)
+        xq = xq * (self.head_dim ** 0.5)
+
+        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
+        return self.o_proj(output)
+
+
+class Gemma4VisionLayer(nn.Module):  # vision encoder layer: attention and MLP, each wrapped in pre/post RMSNorm with a residual add
+    def __init__(self, config, device=None, dtype=None, operations=None):
+        super().__init__()
+        self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, operations=operations)
+        self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, operations=operations)
+        self.input_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype)
+        self.post_attention_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype)
+        self.pre_feedforward_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype)
+        self.post_feedforward_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype)
+
+    def forward(self, x, cos_sin=None, attention_mask=None, optimized_attention=None):
+        residual = x
+        x = self.input_layernorm(x)
+        x = self.self_attn(x, cos_sin=cos_sin, attention_mask=attention_mask, optimized_attention=optimized_attention)
+        x = self.post_attention_layernorm(x)
+        x = residual + x
+
+        residual = x
+        x = self.pre_feedforward_layernorm(x)
+        x = self.mlp(x)
+        x = self.post_feedforward_layernorm(x)
+        x = residual + x
+        return x
+
+
+class Gemma4PatchEmbedder(nn.Module):
+    """Patch embedding with learned 2D position embeddings via one-hot lookup."""
+    def __init__(self, config, device=None, dtype=None, operations=None):
+        super().__init__()
+        hidden_size = config["hidden_size"]
+        patch_size = config["patch_size"]
+        self.patch_size = patch_size
+        self.position_embedding_size = config.get("position_embedding_size", 10240)
+
+        self.input_proj = operations.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype)
+        self.position_embedding_table = nn.Parameter(
+            torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype)
+        )
+
+    def forward(self, pixel_values, pixel_position_ids):
+        """
+        pixel_values: [B, C, H, W] normalized as 2*(x-0.5)
+        pixel_position_ids: [B, num_patches, 2] with (x,y) positions
+        """
+        batch_size, channels, height, width = pixel_values.shape  # NOTE(review): no 2*(x-0.5) rescale is applied in this method — confirm the caller normalises
+        patches_h = height // self.patch_size
+        patches_w = width // self.patch_size
+
+        # Extract and flatten patches: [B, num_patches, 3*patch_size^2]
+        x = pixel_values.reshape(batch_size, channels, patches_h, self.patch_size, patches_w, self.patch_size)
+        x = x.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, patches_h * patches_w, -1)
+
+        hidden_states = self.input_proj(x.to(self.input_proj.weight.dtype))
+
+        # Position embeddings via one-hot lookup
+        clamped_positions = pixel_position_ids.clamp(min=0)
+        one_hot = torch.nn.functional.one_hot(clamped_positions, num_classes=self.position_embedding_size)
+        pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype)
+        one_hot = one_hot.permute(0, 2, 1, 3).to(pos_table)  # [B, 2, num_patches, pos_size]
+        position_embeddings = one_hot @ pos_table  # [B, 2, num_patches, hidden]
+        position_embeddings = position_embeddings.sum(dim=1)  # [B, num_patches, hidden]
+
+        return hidden_states + position_embeddings
+
+
+class Gemma4VisionEncoderLayers(nn.Module):
+    """Wrapper to produce state dict keys as encoder.layers.X.*"""
+    def __init__(self, config, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.layers = nn.ModuleList([
+            Gemma4VisionLayer(config, device=device, dtype=dtype, operations=operations)
+            for _ in range(config["num_hidden_layers"])
+        ])
+
+
+class Gemma4VisionEncoder(nn.Module):  # patch embed -> encoder layers -> position-aware average pooling -> sqrt(hidden_size) scaling
+    def __init__(self, config, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config["hidden_size"]
+        self.head_dim = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"])
+        self.patch_size = config["patch_size"]
+        self.pooling_kernel_size = config.get("pooling_kernel_size", 3)
+        self.root_hidden_size = self.hidden_size ** 0.5
+
+        self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, operations=operations)
+        self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, operations=operations)
+
+    def forward(self, pixel_values):
+        """
+        pixel_values: [B, C, H, W] in [0, 1] range
+        Returns: [B, output_tokens, hidden_size] pooled and scaled vision tokens
+        """
+        batch_size, channels, height, width = pixel_values.shape
+        patches_h = height // self.patch_size
+        patches_w = width // self.patch_size
+        num_patches = patches_h * patches_w
+
+        # Generate position IDs: grid of (col, row) per patch
+        # HF processor uses (x=col, y=row) convention for position_ids
+        rows = torch.arange(patches_h, device=pixel_values.device)
+        cols = torch.arange(patches_w, device=pixel_values.device)
+        grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
+        pixel_position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)  # [num_patches, 2]
+        pixel_position_ids = pixel_position_ids.unsqueeze(0).expand(batch_size, -1, -1)  # [B, num_patches, 2]
+
+        # Patch embedding + position embedding
+        x = self.patch_embedder(pixel_values, pixel_position_ids)
+
+        # Compute 2D RoPE cos/sin for attention
+        cos_sin = _compute_vision_2d_rope(self.head_dim, pixel_position_ids, device=pixel_values.device)
+
+        optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
+
+        for layer in self.encoder.layers:
+            x = layer(x, cos_sin=cos_sin, optimized_attention=optimized_attention)
+
+        # Pooling: position-aware average pooling matching HF's Gemma4VisionPooler
+        k = self.pooling_kernel_size  # 3
+        k_squared = k * k
+        output_length = num_patches // k_squared
+        if num_patches != output_length and output_length > 0:  # NOTE(review): index math below presumes patches_h and patches_w are multiples of k — confirm upstream resizing guarantees it
+            # Assign each patch to a kernel block based on its (col, row) position
+            kernel_col = pixel_position_ids[:, :, 0] // k  # col // k
+            kernel_row = pixel_position_ids[:, :, 1] // k  # row // k
+            stride = patches_w // k  # matches HF's (max_x + 1) // k
+            kernel_idxs = kernel_col + stride * kernel_row  # [B, num_patches]
+
+            # One-hot assignment matrix and weighted average
+            weights = torch.nn.functional.one_hot(kernel_idxs.long(), output_length).float() / k_squared
+            x = (weights.transpose(1, 2) @ x.float()).to(x.dtype)  # [B, output_length, hidden]
+
+        # Scale by sqrt(hidden_size) like HF pooler
+        x = x * self.root_hidden_size
+        return x
+
+
+class Gemma4MultiModalProjector(nn.Module):  # weightless RMS norm followed by a bias-free linear map from vision width to text width
+    def __init__(self, config, dtype=None, device=None, operations=None):
+        super().__init__()
+        vision_hidden_size = config.vision_config["hidden_size"]
+        text_hidden_size = config.hidden_size
+        self.embedding_projection = operations.Linear(vision_hidden_size, text_hidden_size, bias=False, device=device, dtype=dtype)
+
+    def forward(self, vision_outputs):
+        return self.embedding_projection(_parameterless_rms_norm(vision_outputs))
+
+
+# --- Audio Encoder ---
+# Conformer-style architecture matching HF weight structure after conversion:
+# audio_model.subsample_conv_projection.layer0.conv.weight [128, 1, 3, 3]
+# audio_model.subsample_conv_projection.layer0.norm.weight [128]
+# audio_model.subsample_conv_projection.layer1.conv.weight [32, 128, 3, 3]
+# audio_model.subsample_conv_projection.layer1.norm.weight [32]
+# audio_model.subsample_conv_projection.input_proj_linear.weight [1024, 1024]
+# audio_model.layers.X.feed_forward1.{pre,post}_layer_norm.weight [1024]
+# audio_model.layers.X.feed_forward1.ffw_layer_{1,2}.weight [4096/1024, 1024/4096]
+# audio_model.layers.X.self_attn.{q,k,v}_proj.weight [1024, 1024]
+# audio_model.layers.X.self_attn.post.weight [1024, 1024]
+# audio_model.layers.X.self_attn.per_dim_scale [128]
+# audio_model.layers.X.self_attn.relative_k_proj.weight [1024, 1024]
+# audio_model.layers.X.lconv1d.{linear_start,linear_end}.weight, depthwise_conv1d.weight
+# audio_model.layers.X.feed_forward2.* (same as feed_forward1)
+# audio_model.output_proj.{weight, bias}
+
+class Gemma4AudioConvSubsampler(nn.Module):
+    """2D 
convolution subsampling for audio features, matching HF Gemma4AudioSubSampleConvProjection.""" + def __init__(self, config, device=None, dtype=None, operations=None): + super().__init__() + eps = config.get("rms_norm_eps", 1e-6) + self.layer0 = nn.ModuleDict({ + 'conv': operations.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': operations.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + self.layer1 = nn.ModuleDict({ + 'conv': operations.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': operations.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + # proj_input_dim = (128 // 4) * 32 = 1024 + self.input_proj_linear = operations.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) + + def forward(self, x): + # x: [batch, time, features] + x = x.unsqueeze(1) # [batch, 1, time, features] + x = self.layer0['conv'](x.to(self.layer0['conv'].weight.dtype)) + x = torch.relu(self.layer0['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + + x = self.layer1['conv'](x) + x = torch.relu(self.layer1['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + + batch_size, _, seq_len, _ = x.shape + x = x.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1) + return self.input_proj_linear(x) + + +class Gemma4AudioFeedForward(nn.Module): + """Conformer feed-forward with gradient clipping and residual scaling.""" + def __init__(self, config, device=None, dtype=None, operations=None): + super().__init__() + hidden_size = config["hidden_size"] + intermediate_size = config.get("intermediate_size", hidden_size * 4) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.ffw_layer_1 = _make_clipped_linear(hidden_size, intermediate_size, device=device, dtype=dtype, 
operations=operations) + self.ffw_layer_2 = _make_clipped_linear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.post_layer_scale = config.get("residual_weight", 0.5) + self.gradient_clipping = config.get("gradient_clipping", 1e10) + + def forward(self, x): + residual = x + gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) + x = torch.clamp(x, -gc, gc) + x = self.pre_layer_norm(x) + x = torch.nn.functional.silu(self.ffw_layer_1(x)) + x = self.ffw_layer_2(x) + x = torch.clamp(x, -gc, gc) + x = self.post_layer_norm(x) + x = x * self.post_layer_scale + return x + residual + + +class Gemma4AudioRelPositionalEncoding(nn.Module): + """Sinusoidal relative positional encoding for audio attention.""" + def __init__(self, config, device=None, dtype=None): + super().__init__() + hidden_size = config["hidden_size"] + chunk_size = config.get("attention_chunk_size", 12) + context_left = config.get("attention_context_left", 13) + context_right = config.get("attention_context_right", 0) + self.context_size = chunk_size + context_left - 1 + context_right + + import math + num_timescales = hidden_size // 2 + log_inc = math.log(10000.0) / max(num_timescales - 1, 1) + inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).unsqueeze(0).unsqueeze(0) + self.register_buffer("inv_timescales", inv_timescales, persistent=False) + + @torch.no_grad() + def forward(self, hidden_states): + chunk_size = 12 # matches HF hardcoded value + positions = torch.arange(chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1) + scaled = positions * self.inv_timescales.to(device=hidden_states.device) + return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).to(dtype=hidden_states.dtype) + + +class Gemma4AudioAttention(nn.Module): + """Chunked block attention with relative position bias and softcap.""" + def 
__init__(self, config, device=None, dtype=None, operations=None): + super().__init__() + import math + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = self.hidden_size // self.num_heads + self.chunk_size = config.get("attention_chunk_size", 12) + self.max_past_horizon = config.get("attention_context_left", 13) - 1 + self.max_future_horizon = config.get("attention_context_right", 0) + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.q_scale = (self.head_dim ** -0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) + self.softcap = config.get("attention_logit_cap", 50.0) + + self.q_proj = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.k_proj = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.v_proj = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.post = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype)) + self.relative_k_proj = operations.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) + + def _convert_to_block(self, x): + B, S, H, D = x.shape + num_blocks = (S + self.chunk_size - 1) // self.chunk_size + pad = num_blocks * self.chunk_size - S + x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad)) + return x.reshape(B, num_blocks, self.chunk_size, H, D) + + def _extract_block_context(self, x): + B, S, H, D = x.shape + x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)) + x = x.unfold(1, self.context_size, self.chunk_size) + return torch.movedim(x, -1, 2).contiguous() + + def _rel_shift(self, x): + B, H, NB, 
BS, PL = x.shape + CS = self.context_size + x = torch.nn.functional.pad(x, (0, CS + 1 - PL)) + x = x.view(B, H, NB, BS * (CS + 1)) + x = x[..., :BS * CS] + return x.view(B, H, NB, BS, CS) + + def forward(self, x, position_embeddings=None): + B, S, _ = x.shape + + q = self.q_proj(x).float().view(B, S, self.num_heads, self.head_dim) + k = self.k_proj(x).float().view(B, S, self.num_heads, self.head_dim) + v = self.v_proj(x).float().view(B, S, self.num_heads, self.head_dim) + + q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale.float()) + k = k * self.k_scale + + q_blocks = self._convert_to_block(q) + k_context = self._extract_block_context(k) + v_context = self._extract_block_context(v) + num_blocks = q_blocks.shape[1] + + rel_k = self.relative_k_proj(position_embeddings).view(-1, self.num_heads, self.head_dim).to(q.dtype) + + queries = q_blocks.permute(0, 3, 1, 2, 4) # [B, H, NB, CS, D] + matrix_ac = queries @ k_context.permute(0, 3, 1, 4, 2) + + queries_flat = queries.reshape(B, self.num_heads, -1, self.head_dim) + matrix_bd = queries_flat @ rel_k.permute(1, 2, 0) + matrix_bd = matrix_bd.reshape(B, self.num_heads, num_blocks, self.chunk_size, -1) + matrix_bd = self._rel_shift(matrix_bd) + + attn_weights = matrix_ac + matrix_bd + attn_weights = torch.tanh(attn_weights / self.softcap) * self.softcap + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype) + out = attn_weights @ v_context.permute(0, 3, 1, 2, 4) + out = out.permute(0, 2, 3, 1, 4).reshape(B, num_blocks * self.chunk_size, -1) + out = out[:, :S].contiguous() + return self.post(out.to(self.post.linear.weight.dtype)) + + +class Gemma4AudioLConv1d(nn.Module): + """Lightweight convolution with standard GLU.""" + def __init__(self, config, device=None, dtype=None, operations=None): + super().__init__() + hidden_size = config["hidden_size"] + conv_kernel_size = config.get("conv_kernel_size", 5) + self.gradient_clipping = 
config.get("gradient_clipping", 1e10) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.linear_start = _make_clipped_linear(hidden_size, hidden_size * 2, device=device, dtype=dtype, operations=operations) + # Causal conv: left-pad only (no right padding) + self.depthwise_conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) + self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 + self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.linear_end = _make_clipped_linear(hidden_size, hidden_size, device=device, dtype=dtype, operations=operations) + + def forward(self, x): + residual = x + x = self.pre_layer_norm(x) + x = self.linear_start(x) + x = torch.nn.functional.glu(x, dim=-1) # standard GLU, not gelu-gated + x = x.transpose(1, 2) + x = torch.nn.functional.pad(x, (self.conv_left_pad, 0)) + x = self.depthwise_conv1d(x).transpose(1, 2) + gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) + x = torch.clamp(x, -gc, gc) + x = self.conv_norm(x) + x = torch.nn.functional.silu(x) + x = self.linear_end(x) + return x + residual + + +class Gemma4AudioLayer(nn.Module): + """Conformer block: FFN1 -> Attention -> LConv -> FFN2.""" + def __init__(self, config, device=None, dtype=None, operations=None): + super().__init__() + hidden_size = config["hidden_size"] + self.gradient_clipping = config.get("gradient_clipping", 1e10) + self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) + self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, operations=operations) + self.norm_pre_attn = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.norm_post_attn = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, 
dtype=dtype) + self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, operations=operations) + self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) + self.norm_out = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + + def forward(self, x, position_embeddings=None): + gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) + x = self.feed_forward1(x) + + residual = x + x = torch.clamp(x, -gc, gc) + x = self.norm_pre_attn(x) + x = self.self_attn(x, position_embeddings=position_embeddings) + x = torch.clamp(x, -gc, gc) + x = self.norm_post_attn(x) + x = x + residual + + x = self.lconv1d(x) + x = self.feed_forward2(x) + + x = torch.clamp(x, -gc, gc) + x = self.norm_out(x) + return x + + +class Gemma4AudioEncoder(nn.Module): + def __init__(self, config, dtype=None, device=None, operations=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.output_proj_dims = config.get("output_proj_dims", 1536) + + self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, operations=operations) + self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype) + + self.layers = nn.ModuleList([ + Gemma4AudioLayer(config, device=device, dtype=dtype, operations=operations) + for _ in range(config["num_hidden_layers"]) + ]) + + self.output_proj = operations.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) + + def forward(self, audio_features): + x = self.subsample_conv_projection(audio_features) + position_embeddings = self.rel_pos_enc(x) + + for layer in self.layers: + x = layer(x, position_embeddings=position_embeddings) + + x = self.output_proj(x) + return x + + +class Gemma4AudioProjector(nn.Module): + def __init__(self, config, dtype=None, device=None, operations=None): + super().__init__() + audio_output_dim = config.get("audio_output_proj_dims", 1536) + 
text_hidden_size = config.get("text_hidden_size", 2560) + self.embedding_projection = operations.Linear(audio_output_dim, text_hidden_size, bias=False, device=device, dtype=dtype) + + def forward(self, audio_outputs): + return self.embedding_projection(_parameterless_rms_norm(audio_outputs)) + + +# --- Tokenizer & Wrappers --- + +class Gemma4_Tokenizer(): + def state_dict(self): + return {} + + def _extract_mel_spectrogram(self, waveform, sample_rate): + """Extract log mel spectrogram using HF's Gemma4AudioFeatureExtractor.""" + import torchaudio + from transformers.models.gemma4.feature_extraction_gemma4 import Gemma4AudioFeatureExtractor + if sample_rate != 16000: + waveform = torchaudio.functional.resample(waveform, sample_rate, 16000) + if waveform.dim() > 1 and waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + # Convert to numpy for HF feature extractor + audio_np = waveform.squeeze(0).numpy() + fe = Gemma4AudioFeatureExtractor() + result = fe([audio_np], return_tensors='pt') + return result['input_features'][0] # [T, 128] + + def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, llama_template=None, skip_template=True, thinking=False, **kwargs): + if thinking: + self.llama_template = "<|turn>system\n<|think|>\n<|turn>user\n{}\n<|turn>model\n" + self.llama_template_images = "<|turn>system\n<|think|>\n<|turn>user\n\n\n<|image><|image|>\n\n{}\n<|turn>model\n" + else: + self.llama_template = "<|turn>user\n{}\n<|turn>model\n" + self.llama_template_images = "<|turn>user\n\n\n<|image><|image|>\n\n{}\n<|turn>model\n" + + # Process audio + audio_features = [] + if audio is not None: + waveform = audio["waveform"].squeeze(0) if isinstance(audio, dict) else audio + sample_rate = audio.get("sample_rate", 16000) if isinstance(audio, dict) else 16000 + mel = self._extract_mel_spectrogram(waveform, sample_rate) + audio_features = [mel.unsqueeze(0)] # [1, T, 
128] + + if image is None: + images = [] + else: + samples = image.movedim(-1, 1) # [B, C, H, W] + h, w = samples.shape[2], samples.shape[3] + # Aspect-ratio-preserving resize matching HF Gemma4ImageProcessor + patch_size = 16 + pooling_k = 3 + max_patches = 280 * pooling_k * pooling_k # 2520 + target_px = max_patches * patch_size * patch_size + factor = (target_px / (h * w)) ** 0.5 + side_mult = pooling_k * patch_size # 48 + target_h = max(int(factor * h // side_mult) * side_mult, side_mult) + target_w = max(int(factor * w // side_mult) * side_mult, side_mult) + + # Resize via PIL to match HF processor (operates on uint8, not float tensors) + from PIL import Image + import numpy as np + img_uint8 = (samples[0].permute(1, 2, 0).clamp(0, 1) * 255).byte().cpu().numpy() + pil_img = Image.fromarray(img_uint8).resize((target_w, target_h), Image.BICUBIC) + s = torch.from_numpy(np.array(pil_img).astype(np.float32) / 255.0) + s = s.permute(2, 0, 1).unsqueeze(0).to(samples.device) + s = 2 * (s - 0.5) # normalize [0,1] -> [-1,1] + images = [s.movedim(1, -1)[:, :, :, :3]] + + if text.startswith('<|turn>'): + skip_template = True + + if skip_template: + llama_text = text + else: + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + elif len(audio_features) > 0: + llama_text = f"<|turn>user\n\n<|audio><|audio|>{text}\n<|turn>model\n" + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + + if len(images) > 0: + embed_count = 0 + for r in text_tokens: + for i, token in enumerate(r): + if token[0] == 258880 and embed_count < len(images): + r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:] + embed_count += 1 + + if len(audio_features) > 0: + embed_count = 0 + for r in text_tokens: + for i, token in enumerate(r): + if token[0] == 258881 and embed_count < len(audio_features): + r[i] 
= ({"type": "audio", "data": audio_features[embed_count]},) + token[1:] + embed_count += 1 + + return text_tokens + + +class Gemma4HFTokenizer: + """Wrapper to load GemmaTokenizer from tokenizer.json bytes embedded in safetensors.""" + def __init__(self, tokenizer_json_bytes=None, **kwargs): + import tempfile, os, json + from transformers import AutoTokenizer + self.temp_dir = tempfile.mkdtemp() + if isinstance(tokenizer_json_bytes, torch.Tensor): + tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) + with open(os.path.join(self.temp_dir, "tokenizer.json"), "wb") as f: + f.write(tokenizer_json_bytes) + # Minimal tokenizer_config.json + with open(os.path.join(self.temp_dir, "tokenizer_config.json"), "w") as f: + json.dump({"tokenizer_class": "GemmaTokenizer", "add_bos_token": True, "add_eos_token": False}, f) + self.tokenizer = AutoTokenizer.from_pretrained(self.temp_dir) + + @classmethod + def from_pretrained(cls, tokenizer_data, **kwargs): + return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) + + def __call__(self, text): + ids = self.tokenizer.encode(text, add_special_tokens=False) + return {"input_ids": ids} + + def get_vocab(self): + return self.tokenizer.get_vocab() + + def convert_tokens_to_ids(self, tokens): + return self.tokenizer.convert_tokens_to_ids(tokens) + + def decode(self, ids, **kwargs): + return self.tokenizer.decode(ids, **kwargs) + + +class Gemma4_E4BTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_json = tokenizer_data.get("tokenizer_json", None) + super().__init__(tokenizer_json, pad_with_end=False, embedding_size=2560, embedding_key='gemma4_e4b', tokenizer_class=Gemma4HFTokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) + + +class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, 
embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4_e4b", tokenizer=Gemma4_E4BTokenizer) + + +class Gemma4_E4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + self.dtypes = set() + self.dtypes.add(dtype) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma4_E4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def process_tokens(self, tokens, device): + embeds, _, _, embeds_info = super().process_tokens(tokens, device) + scale = self.transformer.model.config.hidden_size ** 0.5 + # Undo text embedding scaling for multimodal tokens (vision/audio) + for info in embeds_info: + start_idx = info["index"] + end_idx = start_idx + info["size"] + embeds[:, start_idx:end_idx, :] /= scale + return embeds + + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0): + if isinstance(tokens, dict): + tokens = next(iter(tokens.values())) + tokens_only = [[t[0] for t in b] for b in tokens] + embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device) + # Build input_ids matching embeds length for per-layer embeddings + # HF uses pad_token_id (0) at multimodal positions, not the placeholder ID + base_ids = [t if isinstance(t, int) else 0 for t in tokens_only[0]] + # Expand: each multimodal position was 1 token, now occupies `size` 
positions + initial_token_ids = [base_ids] + for info in sorted(embeds_info, key=lambda i: i["index"], reverse=True): + idx, size = info["index"], info["size"] + initial_token_ids[0] = initial_token_ids[0][:idx] + [0] * size + initial_token_ids[0][idx + 1:] + scale = self.transformer.model.config.hidden_size ** 0.5 + for info in embeds_info: + start_idx = info["index"] + end_idx = start_idx + info["size"] + embeds[:, start_idx:end_idx, :] /= scale + input_ids = torch.tensor(initial_token_ids, device=self.execution_device) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) + + +def gemma4_te(dtype_llama=None, llama_quantization_metadata=None): + class Gemma4TEModel_(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name="gemma4_e4b", clip_model=Gemma4_E4BModel, model_options=model_options) + return Gemma4TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 06f2fbf74..ad0965161 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -666,7 +666,7 @@ class Llama2_(nn.Module): self.config.rope_dims, device=device) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, 
input_ids=None): if embeds is not None: x = embeds else: @@ -826,7 +826,7 @@ class BaseGenerate: torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) return past_key_values - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None): device = embeds.device if stop_tokens is None: @@ -851,14 +851,16 @@ class BaseGenerate: pbar = comfy.utils.ProgressBar(max_length) # Generation loop + current_input_ids = initial_input_ids for step in tqdm(range(max_length), desc="Generating tokens"): - x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids) logits = self.logits(x)[:, -1] next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) token_id = next_token[0].item() generated_token_ids.append(token_id) embeds = self.model.embed_tokens(next_token).to(execution_dtype) + current_input_ids = next_token if initial_input_ids is not None else None pbar.update(1) if token_id in stop_tokens: diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index f1aeb63fa..d2fa48223 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -32,6 +32,7 @@ class 
TextGenerate(io.ComfyNode): io.Clip.Input("clip"), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.Image.Input("image", optional=True), + io.Audio.Input("audio", optional=True), io.Int.Input("max_length", default=256, min=1, max=2048), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), @@ -42,9 +43,9 @@ class TextGenerate(io.ComfyNode): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, audio=None, thinking=False) -> io.NodeOutput: - tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking) + tokens = clip.tokenize(prompt, image=image, audio=audio, skip_template=False, min_length=1, thinking=thinking) # Get sampling parameters from dynamic combo do_sample = sampling_mode.get("sampling_mode") == "on" From 93e86351109d10c5471d67c045a9fc9c3d860ddf Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 7 Apr 2026 01:15:04 +0300 Subject: [PATCH 02/18] parity with reference implementation outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize --- comfy/ldm/modules/attention.py | 20 +- comfy/rmsnorm.py | 10 +- comfy/sd.py | 18 +- comfy/text_encoders/gemma4.py | 950 +++++++++++++++++++-------------- comfy/text_encoders/llama.py | 28 +- comfy/text_encoders/lt.py | 3 +- comfy/text_encoders/lumina2.py | 3 +- comfy/text_encoders/qwen35.py | 2 - comfy/utils.py | 7 - comfy_extras/nodes_textgen.py | 5 +- 10 files changed, 613 insertions(+), 433 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b193fe5e8..43cecad7f 100644 --- a/comfy/ldm/modules/attention.py +++ 
b/comfy/ldm/modules/attention.py @@ -150,7 +150,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = q.shape dim_head //= heads - scale = dim_head ** -0.5 + if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]: + n_rep = q.shape[-3] // k.shape[-3] + k = k.repeat_interleave(n_rep, dim=-3) + v = v.repeat_interleave(n_rep, dim=-3) + + scale = kwargs.get("scale", dim_head ** -0.5) h = heads if skip_reshape: @@ -219,6 +224,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, b, _, dim_head = query.shape dim_head //= heads + if "scale" in kwargs: + # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head)) + query = query * (kwargs["scale"] * dim_head ** 0.5) + if skip_reshape: query = query.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head) @@ -290,7 +299,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = q.shape dim_head //= heads - scale = dim_head ** -0.5 + scale = kwargs.get("scale", dim_head ** -0.5) if skip_reshape: q, k, v = map( @@ -500,8 +509,11 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.ndim == 3: mask = mask.unsqueeze(1) + # Pass through extra SDPA kwargs (scale, enable_gqa) if provided + sdpa_extra = {k: v for k, v in kwargs.items() if k in ("scale", "enable_gqa")} + if SDP_BATCH_LIMIT >= b: - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -519,7 +531,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=m, - dropout_p=0.0, is_causal=False + dropout_p=0.0, 
is_causal=False, **sdpa_extra ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index ab7cf14fa..5e5ef359a 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -3,7 +3,15 @@ import comfy.model_management RMSNorm = torch.nn.RMSNorm -def rms_norm(x, weight=None, eps=1e-6): +def rms_norm(x, weight=None, eps=1e-6, fused=True): + if not fused: + orig_dtype = x.dtype + normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5) + if weight is not None: + weight = comfy.model_management.cast_to(weight, dtype=torch.float32, device=x.device) + normed = normed * weight + return normed.to(orig_dtype) + if weight is None: return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) else: diff --git a/comfy/sd.py b/comfy/sd.py index 9b1960286..7565e0f9e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1230,6 +1230,8 @@ class TEModel(Enum): QWEN35_9B = 26 QWEN35_27B = 27 GEMMA_4_E4B = 28 + GEMMA_4_E2B = 29 + GEMMA_4_31B = 30 def detect_te_model(sd): @@ -1255,8 +1257,12 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.59.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_4_31B if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd: return TEModel.GEMMA_4_E4B + if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd: + return TEModel.GEMMA_4_E2B if 'model.layers.47.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: @@ -1280,7 +1286,7 @@ def detect_te_model(sd): if weight.shape[0] == 4096: return TEModel.QWEN35_9B if weight.shape[0] == 5120: - return TEModel.QWEN35_27B + return TEModel.QWEN35_31B return TEModel.QWEN35_2B if "model.layers.0.post_attention_layernorm.weight" in sd: weight = 
sd['model.layers.0.post_attention_layernorm.weight'] @@ -1395,9 +1401,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer elif te_model == TEModel.GEMMA_4_E4B: - clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data)) + clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E4B) clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4Tokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) + elif te_model == TEModel.GEMMA_4_E2B: + clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E2B) + clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4_E2BTokenizerWrapper + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) + elif te_model == TEModel.GEMMA_4_31B: + clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_31B) + clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4_31BTokenizerWrapper + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index c3a964cc4..ee4c672f2 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -1,19 +1,21 @@ import torch import torch.nn as nn +import numpy as np from dataclasses import dataclass +import math from comfy import sd1_clip -import comfy.utils import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device -from 
comfy.text_encoders.llama import RMSNorm, BaseLlama, BaseGenerate, Llama2_ +from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook GEMMA4_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "model_type": "gemma4_vision", "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_VISION_31B_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "gemma4_vision", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5, "gradient_clipping": 1e10, "hidden_act": "silu"} @dataclass -class Gemma4_E4B_Config: +class Gemma4Config: vocab_size: int = 262144 hidden_size: int = 2560 intermediate_size: int = 10240 @@ -40,44 +42,60 @@ class Gemma4_E4B_Config: final_logit_softcapping: float = 30.0 hidden_size_per_layer_input: int = 256 num_kv_shared_layers: int = 18 - stop_tokens = [1, 106] + use_double_wide_mlp: bool = False + stop_tokens = [1, 50, 106] + fused_rms_norm: bool = False # True = use fused F.rms_norm (~64% faster, minor output difference from reference) vision_config = GEMMA4_VISION_CONFIG audio_config = GEMMA4_AUDIO_CONFIG mm_tokens_per_image = 280 +Gemma4_E4B_Config = Gemma4Config -def precompute_freqs_cis_proportional(head_dim, partial_rotary_factor, position_ids, theta, device=None): - 
"""Proportional RoPE: compute freqs for full head_dim, but only first rope_angles get non-zero frequencies.""" - rope_angles = int(partial_rotary_factor * head_dim // 2) - nope_angles = head_dim // 2 - rope_angles +@dataclass +class Gemma4_E2B_Config(Gemma4Config): + hidden_size: int = 1536 + intermediate_size: int = 6144 + num_hidden_layers: int = 35 + num_key_value_heads: int = 1 + sliding_attention = [512, 512, 512, 512, False] + num_kv_shared_layers: int = 20 + use_double_wide_mlp: bool = True - theta_numerator = torch.arange(0, 2 * rope_angles, 2, device=device).float() - inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) +@dataclass +class Gemma4_31B_Config(Gemma4Config): + hidden_size: int = 5376 + intermediate_size: int = 21504 + num_hidden_layers: int = 60 + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + hidden_size_per_layer_input: int = 0 + num_kv_shared_layers: int = 0 + vision_config = GEMMA4_VISION_31B_CONFIG - if nope_angles > 0: - inv_freq = torch.cat([inv_freq, torch.zeros(nope_angles, device=device)], dim=0) - inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().unsqueeze(1) - sin = emb.sin().unsqueeze(1) - sin_split = sin.shape[-1] // 2 - return (cos, sin[..., :sin_split], -sin[..., sin_split:]) +def _apply_rotary_pos_emb(x, freqs_cis): + cos, sin = freqs_cis[0], freqs_cis[1] + half = x.shape[-1] // 2 + out = x * cos + out[..., :half] -= x[..., half:] * sin[..., :half] + out[..., half:] += x[..., :half] * sin[..., half:] + return out + + +def _apply_rope_gemma(xq, xk, freqs_cis): + return _apply_rotary_pos_emb(xq, freqs_cis), _apply_rotary_pos_emb(xk, freqs_cis) class Gemma4Attention(nn.Module): def __init__(self, config, head_dim, device=None, 
dtype=None, ops=None): super().__init__() - from comfy.text_encoders.llama import RMSNorm self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.hidden_size = config.hidden_size self.head_dim = head_dim self.inner_size = self.num_heads * head_dim - ops = ops or nn self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype) self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) @@ -85,22 +103,22 @@ class Gemma4Attention(nn.Module): self.q_norm = None self.k_norm = None + fused = getattr(config, 'fused_rms_norm', False) if config.q_norm == "gemma3": - self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.k_norm == "gemma3": - self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) def forward( self, hidden_states: torch.Tensor, attention_mask=None, freqs_cis=None, - optimized_attention=None, past_key_value=None, sliding_window=None, shared_kv=None, + **kwargs, ): - from comfy.text_encoders.llama import apply_rope batch_size, seq_length, _ = hidden_states.shape xq = self.q_proj(hidden_states) @@ -109,66 +127,58 @@ class Gemma4Attention(nn.Module): xq = self.q_norm(xq) if shared_kv is not None: - # KV-shared layer: borrow KV from source layer, skip own cache - if len(shared_kv) == 3: - xk, xv = shared_kv[0][:, :, :shared_kv[2]], shared_kv[1][:, :, :shared_kv[2]] - else: - xk, xv = shared_kv + xk, xv = shared_kv # Apply RoPE to Q only (K already has RoPE from source layer) 
- xq, _ = apply_rope(xq, xq, freqs_cis=freqs_cis) # dummy K, only Q result used + xq = _apply_rotary_pos_emb(xq, freqs_cis) present_key_value = None shareable_kv = None else: - xk = self.k_proj(hidden_states) - xv = self.v_proj(hidden_states) - xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) - xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) + xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) if self.k_norm is not None: xk = self.k_norm(xk) xv = _parameterless_rms_norm(xv) - xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + xq, xk = _apply_rope_gemma(xq, xk, freqs_cis=freqs_cis) present_key_value = None if past_key_value is not None: - index = 0 - num_tokens = xk.shape[2] + cumulative_len = 0 if len(past_key_value) > 0: - past_key, past_value, index = past_key_value - if past_key.shape[2] >= (index + num_tokens): - past_key[:, :, index:index + xk.shape[2]] = xk - past_value[:, :, index:index + xv.shape[2]] = xv - xk = past_key[:, :, :index + xk.shape[2]] - xv = past_value[:, :, :index + xv.shape[2]] - present_key_value = (past_key, past_value, index + num_tokens) - else: - xk = torch.cat((past_key[:, :, :index], xk), dim=2) - xv = torch.cat((past_value[:, :, :index], xv), dim=2) - present_key_value = (xk, xv, index + num_tokens) + past_key, past_value, cumulative_len = past_key_value + xk = torch.cat((past_key, xk), dim=2) + xv = torch.cat((past_value, xv), dim=2) + new_cumulative = cumulative_len + seq_length + if sliding_window is not None and xk.shape[2] > sliding_window - 1: + cache_k = xk[:, :, -(sliding_window - 1):] + cache_v = xv[:, :, -(sliding_window - 1):] else: - present_key_value = (xk, xv, index + num_tokens) + cache_k = xk + cache_v = xv + present_key_value = (cache_k, cache_v, 
new_cumulative) - if sliding_window is not None and xk.shape[2] > sliding_window: - xk = xk[:, :, -sliding_window:] - xv = xv[:, :, -sliding_window:] - attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None + # KV for sharing: full xk/xv that SDPA sees (not evicted cache) + shareable_kv = (xk, xv) - # KV for sharing with later layers - shareable_kv = present_key_value if present_key_value is not None else (xk, xv) + # GQA: pass unexpanded KV with enable_gqa when no sliding mask, + # expand heads when sliding mask is present + # has to be done within SDPA itself to match the reference code, pre-scaling expansion causes numerical differences + expand_kv = (self.num_heads != self.num_kv_heads and + sliding_window is not None and + xk.shape[2] >= sliding_window) + if expand_kv: + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + gqa_kwargs = {} if expand_kv else ({"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}) + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0, **gqa_kwargs) - xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) - xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) - - # scaling=1.0: pre-multiply Q to cancel optimized_attention's 1/sqrt(head_dim) - xq = xq * (self.head_dim ** 0.5) - - output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) return self.o_proj(output), present_key_value, shareable_kv class TransformerBlockGemma4(nn.Module): def __init__(self, config, index, device=None, dtype=None, ops=None): super().__init__() - from comfy.text_encoders.llama import MLP if config.sliding_attention is not None: self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] 
else: @@ -177,31 +187,36 @@ class TransformerBlockGemma4(nn.Module): head_dim = config.head_dim if self.sliding_attention else config.global_head_dim self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops) - self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + num_kv_shared = getattr(config, 'num_kv_shared_layers', 0) + first_kv_shared = config.num_hidden_layers - num_kv_shared + mlp_size = config.intermediate_size * 2 if getattr(config, 'use_double_wide_mlp', False) and index >= first_kv_shared else None + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) + + fused = getattr(config, 'fused_rms_norm', False) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0) if self.hidden_size_per_layer_input: - ops_pl = ops or nn - self.per_layer_input_gate = ops_pl.Linear(config.hidden_size, 
self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) - self.per_layer_projection = ops_pl.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) - self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) + self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) + self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) else: self.layer_scalar = None - def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, - past_key_value=None, per_layer_input=None, shared_kv=None): + def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None): sliding_window = None if self.sliding_attention: sliding_window = self.sliding_attention + # For prefill > sliding window, add sliding window restriction to the causal mask. 
if x.shape[1] > self.sliding_attention: - sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype) - sliding_mask.tril_(diagonal=-self.sliding_attention) - attention_mask = attention_mask + sliding_mask if attention_mask is not None else sliding_mask + sw_mask = torch.zeros(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) + sw_mask.masked_fill_(torch.ones_like(sw_mask, dtype=torch.bool).tril_(-self.sliding_attention), torch.finfo(x.dtype).min) + attention_mask = attention_mask + sw_mask if attention_mask is not None else sw_mask freqs_cis = freqs_cis[1] else: freqs_cis = freqs_cis[0] @@ -210,8 +225,7 @@ class TransformerBlockGemma4(nn.Module): x = self.input_layernorm(x) x, present_key_value, shareable_kv = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, - optimized_attention=optimized_attention, past_key_value=past_key_value, - sliding_window=sliding_window, shared_kv=shared_kv, + past_key_value=past_key_value, sliding_window=sliding_window, shared_kv=shared_kv, ) x = self.post_attention_layernorm(x) x = residual + x @@ -237,50 +251,79 @@ class TransformerBlockGemma4(nn.Module): return x, present_key_value, shareable_kv -class Gemma4Transformer(Llama2_): - """Llama2_ subclass with Gemma4-specific features: per-layer inputs, KV sharing, proportional RoPE.""" +class Gemma4Transformer(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): - super().__init__(config, device=device, dtype=dtype, ops=ops) - # Override transformer type - self.normalize_in = True - # Replace layers with Gemma4 blocks + super().__init__() + self.config = config + fused = getattr(config, 'fused_rms_norm', False) + + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) + self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False) + 
self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook) + self.layers = nn.ModuleList([ TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops) for i in range(config.num_hidden_layers) ]) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.final_norm else None + + # Precompute RoPE inv_freq on CPU to match reference code's exact value + rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2) + nope_global = config.global_head_dim // 2 - rope_angles_global + global_inv = 1.0 / (config.rope_theta[0] ** (torch.arange(0, 2 * rope_angles_global, 2).float() / config.global_head_dim)) + if nope_global > 0: + global_inv = torch.cat([global_inv, torch.zeros(nope_global)]) + self.register_buffer("_global_inv_freq", global_inv, persistent=False) + + sliding_inv = 1.0 / (config.rope_theta[1] ** (torch.arange(0, config.head_dim, 2).float() / config.head_dim)) + self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False) + # Per-layer input mechanism self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0) if self.hidden_size_per_layer_input: - self.embed_tokens_per_layer = ops.Embedding( - config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, - device=device, dtype=dtype) + self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype) + self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False) + self.embed_tokens_per_layer.register_forward_hook(_gemma_embed_scale_hook) self.per_layer_model_projection = ops.Linear( config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) self.per_layer_projection_norm = 
RMSNorm( self.hidden_size_per_layer_input, eps=config.rms_norm_eps, - add=config.rms_norm_add, device=device, dtype=dtype) + device=device, dtype=dtype, fused=fused) - def compute_freqs_cis(self, position_ids, device): - from comfy.text_encoders.llama import precompute_freqs_cis - global_freqs = precompute_freqs_cis_proportional( - self.config.global_head_dim, self.config.partial_rotary_factor, - position_ids, self.config.rope_theta[0], device=device) - sliding_freqs = precompute_freqs_cis( - self.config.head_dim, position_ids, self.config.rope_theta[1], device=device) + def get_past_len(self, past_key_values): + for kv in past_key_values: + if len(kv) >= 3: + return kv[2] + return 0 + + def _freqs_from_inv(self, inv_freq, position_ids, dtype=None): + """Compute cos/sin from stored inv_freq""" + inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(position_ids.device) + pos_exp = position_ids[:, None, :].float() + freqs = (inv_exp @ pos_exp).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().unsqueeze(1) + sin = emb.sin().unsqueeze(1) + result = (cos, sin) + if dtype is not None: + result = tuple(t.to(dtype) for t in result) + return result + + def compute_freqs_cis(self, position_ids, device, dtype=None): + global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, dtype) + sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, dtype) return [global_freqs, sliding_freqs] def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, - final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], + final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=None, past_key_values=None, input_ids=None): if embeds is not None: x = embeds else: x = self.embed_tokens(x, out_dtype=dtype) - if self.normalize_in: - x *= self.config.hidden_size ** 0.5 - seq_len = x.shape[1] past_len = 0 if past_key_values is not 
None and len(past_key_values) > 0: @@ -289,19 +332,19 @@ class Gemma4Transformer(Llama2_): if position_ids is None: position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) - freqs_cis = self.compute_freqs_cis(position_ids, x.device) + freqs_cis = self.compute_freqs_cis(position_ids, x.device, dtype=x.dtype) mask = None + min_val = torch.finfo(x.dtype).min if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4) + mask = mask.masked_fill(mask.to(torch.bool), min_val) if seq_len > 1: - causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1) + causal_mask = torch.zeros(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device) + causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val) mask = mask + causal_mask if mask is not None else causal_mask - optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) - # Per-layer inputs per_layer_inputs = None if self.hidden_size_per_layer_input: @@ -310,7 +353,7 @@ class Gemma4Transformer(Llama2_): per_layer_proj = self.per_layer_model_projection(x) * (1.0 / (self.config.hidden_size ** 0.5)) per_layer_proj = self.per_layer_projection_norm(per_layer_proj.reshape(*x.shape[:-1], num_layers, hpl)) if input_ids is not None and input_ids.shape[1] == x.shape[1]: - per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl) * (hpl ** 0.5) + per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl) per_layer_inputs = (per_layer_proj + per_layer_emb) * (0.5 ** 0.5) else: per_layer_inputs = per_layer_proj @@ -329,20 +372,19 @@ 
class Gemma4Transformer(Llama2_): layer_kwargs = {} if per_layer_inputs is not None: layer_kwargs['per_layer_input'] = per_layer_inputs[:, :, i, :] + + is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention if i >= first_kv_shared and num_kv_shared > 0: - is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention shared = shared_sliding_kv if is_sliding else shared_global_kv if shared is not None: layer_kwargs['shared_kv'] = shared - x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis, - optimized_attention=optimized_attention, past_key_value=past_kv, **layer_kwargs) + x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis, past_key_value=past_kv, **layer_kwargs) next_key_values.append(current_kv if current_kv is not None else ()) # Only track the last sliding/global before the sharing boundary if i < first_kv_shared and shareable_kv is not None: - is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention if is_sliding: shared_sliding_kv = shareable_kv else: @@ -359,19 +401,14 @@ class Gemma4Transformer(Llama2_): return x, intermediate -class Gemma4_E4B(BaseLlama, BaseGenerate, torch.nn.Module): - def __init__(self, config_dict, dtype, device, operations): - super().__init__() - config = Gemma4_E4B_Config(**config_dict) +class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): + """Common base for all Gemma4 variants: text model + vision.""" + def _init_model(self, config, dtype, device, operations): self.num_layers = config.num_hidden_layers - self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype - self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype, device, operations) self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype, device, operations) - self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype, device, operations) - self.audio_projector = 
Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype, device, operations) def logits(self, x): logits = super().logits(x) @@ -381,43 +418,61 @@ class Gemma4_E4B(BaseLlama, BaseGenerate, torch.nn.Module): return logits def init_kv_cache(self, batch, max_cache_len, device, execution_dtype): - config = self.model.config - num_kv_shared = getattr(config, 'num_kv_shared_layers', 0) - first_kv_shared = config.num_hidden_layers - num_kv_shared past_key_values = [] - for i in range(config.num_hidden_layers): - if i >= first_kv_shared: - past_key_values.append(()) # shared layers don't need KV cache - else: - sa = config.sliding_attention[i % len(config.sliding_attention)] - hd = config.head_dim if sa else config.global_head_dim - past_key_values.append(( - torch.empty([batch, config.num_key_value_heads, max_cache_len, hd], device=device, dtype=execution_dtype), - torch.empty([batch, config.num_key_value_heads, max_cache_len, hd], device=device, dtype=execution_dtype), - 0)) + for _ in range(self.model.config.num_hidden_layers): + past_key_values.append(()) return past_key_values def preprocess_embed(self, embed, device): if embed["type"] == "image": image = embed["data"].movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] - vision_out = self.vision_model(image.to(device, dtype=torch.float32)) + max_soft_tokens = embed.get("max_soft_tokens", None) + vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) return self.multi_modal_projector(vision_out), None + if embed["type"] == "video": + frame_idx = embed.get("frame_idx", 0) + if not hasattr(self, '_video_cache') or self._video_cache is None: + # First frame: process all frames as a batch + frames = embed["data"].movedim(-1, 1) # [N, H, W, C] -> [N, C, H, W] + max_soft_tokens = embed.get("max_soft_tokens", None) + vision_out = self.vision_model(frames.to(device, 
dtype=torch.float32), max_soft_tokens=max_soft_tokens) + projected = self.multi_modal_projector(vision_out) # [N, tokens_per_frame, hidden] + self._video_cache = projected + result = self._video_cache[frame_idx:frame_idx+1] # [1, tokens_per_frame, hidden] + if frame_idx == self._video_cache.shape[0] - 1: + self._video_cache = None # clear after last frame + return result, None + return None, None + + +class Gemma4AudioMixin: + """Adds audio support to a Gemma4 model.""" + def _init_audio(self, config, dtype, device, operations): + self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype, device, operations) + self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype, device, operations) + + def preprocess_embed(self, embed, device): + result, extra = super().preprocess_embed(embed, device) + if result is not None: + return result, extra if embed["type"] == "audio": audio = embed["data"].to(device, dtype=torch.float32) - audio_out = self.audio_model(audio) + audio_mask = embed.get("mask", None) + if audio_mask is not None: + audio_mask = audio_mask.to(device) + audio_out = self.audio_model(audio, audio_mask=audio_mask) return self.audio_projector(audio_out), None return None, None -# --- Vision Encoder --- -# Matches HF weight structure after conversion: -# vision_model.patch_embedder.input_proj.weight [768, 768] -# vision_model.patch_embedder.position_embedding_table [2, 10240, 768] -# vision_model.encoder.layers.X.self_attn.{q,k,v,o}_proj.weight [768, 768] -# vision_model.encoder.layers.X.self_attn.{q,k}_norm.weight [64] -# vision_model.encoder.layers.X.mlp.{gate,up}_proj.weight [3072, 768] -# vision_model.encoder.layers.X.mlp.down_proj.weight [768, 3072] -# vision_model.encoder.layers.X.{input,post_attention,pre_feedforward,post_feedforward}_layernorm.weight [768] +class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base): + def 
__init__(self, config_dict, dtype, device, operations): + super().__init__() + self._init_model(Gemma4_E4B_Config(**config_dict), dtype, device, operations) + self._init_audio(self.model.config, dtype, device, operations) + + +# Vision Encoder def _parameterless_rms_norm(x, eps=1e-6): """RMSNorm without learnable weight (used by Gemma4 v_norm and projectors).""" @@ -506,19 +561,15 @@ class ClippedLinear(nn.Module): return x -def _make_clipped_linear(in_f, out_f, bias=False, device=None, dtype=None, operations=None): - return ClippedLinear(in_f, out_f, bias=bias, device=device, dtype=dtype, operations=operations) - - class Gemma4VisionMLP(nn.Module): """SwiGLU MLP matching gate_proj/up_proj/down_proj structure.""" def __init__(self, config, device=None, dtype=None, operations=None): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config["intermediate_size"] - self.gate_proj = _make_clipped_linear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.up_proj = _make_clipped_linear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.down_proj = _make_clipped_linear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) + self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) + self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) def forward(self, x): return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)) @@ -531,15 +582,16 @@ class Gemma4VisionAttention(nn.Module): self.num_heads = config["num_attention_heads"] self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads) - self.q_proj = _make_clipped_linear(self.hidden_size, self.num_heads * 
self.head_dim, device=device, dtype=dtype, operations=operations) - self.k_proj = _make_clipped_linear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.v_proj = _make_clipped_linear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.o_proj = _make_clipped_linear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) + self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) + self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) + self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, operations=operations) - def forward(self, x, cos_sin=None, attention_mask=None, optimized_attention=None): + self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + + def forward(self, x, cos_sin=None, attention_mask=None, **kwargs): batch_size, seq_length, _ = x.shape xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) @@ -550,23 +602,18 @@ class Gemma4VisionAttention(nn.Module): xk = self.k_norm(xk) xv = _parameterless_rms_norm(xv) + xq = xq.transpose(1, 2) # [B, H, S, D] + xk = xk.transpose(1, 2) + # Apply 2D RoPE if cos_sin is not None: cos, sin = cos_sin - xq = xq.transpose(1, 2) # [B, 
H, S, D] - xk = xk.transpose(1, 2) xq = _apply_vision_2d_rope(xq, cos, sin) xk = _apply_vision_2d_rope(xk, cos, sin) - else: - xq = xq.transpose(1, 2) - xk = xk.transpose(1, 2) xv = xv.to(xq.dtype).transpose(1, 2) - # scaling=1.0 (Q/K already normalized), cancel optimized_attention's 1/sqrt(d) - xq = xq * (self.head_dim ** 0.5) - - output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0) return self.o_proj(output) @@ -575,15 +622,17 @@ class Gemma4VisionLayer(nn.Module): super().__init__() self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, operations=operations) self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, operations=operations) - self.input_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.post_attention_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.pre_feedforward_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.post_feedforward_layernorm = RMSNorm(config["hidden_size"], eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + hidden = config["hidden_size"] + self.input_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs) + self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) - def forward(self, x, cos_sin=None, attention_mask=None, optimized_attention=None): + def forward(self, x, cos_sin=None, attention_mask=None): residual = x x = self.input_layernorm(x) - x = 
self.self_attn(x, cos_sin=cos_sin, attention_mask=attention_mask, optimized_attention=optimized_attention) + x = self.self_attn(x, cos_sin=cos_sin, attention_mask=attention_mask) x = self.post_attention_layernorm(x) x = residual + x @@ -609,28 +658,22 @@ class Gemma4PatchEmbedder(nn.Module): torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype) ) - def forward(self, pixel_values, pixel_position_ids): + def forward(self, patches, pixel_position_ids): """ - pixel_values: [B, C, H, W] normalized as 2*(x-0.5) - pixel_position_ids: [B, num_patches, 2] with (x,y) positions + patches: [B, num_patches, 3*patch_size²] in [0,1] range (normalized to [-1,1] inside, matching HF) + pixel_position_ids: [B, num_patches, 2] with (x,y) positions, (-1,-1) for padding """ - batch_size, channels, height, width = pixel_values.shape - patches_h = height // self.patch_size - patches_w = width // self.patch_size + hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype)) - # Extract and flatten patches: [B, num_patches, 3*patch_size^2] - x = pixel_values.reshape(batch_size, channels, patches_h, self.patch_size, patches_w, self.patch_size) - x = x.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, patches_h * patches_w, -1) - - hidden_states = self.input_proj(x.to(self.input_proj.weight.dtype)) - - # Position embeddings via one-hot lookup clamped_positions = pixel_position_ids.clamp(min=0) one_hot = torch.nn.functional.one_hot(clamped_positions, num_classes=self.position_embedding_size) pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype) - one_hot = one_hot.permute(0, 2, 1, 3).to(pos_table) # [B, 2, num_patches, pos_size] - position_embeddings = one_hot @ pos_table # [B, 2, num_patches, hidden] - position_embeddings = position_embeddings.sum(dim=1) # [B, num_patches, hidden] + one_hot = one_hot.permute(0, 2, 1, 3).to(pos_table) + position_embeddings = (one_hot @ 
pos_table).sum(dim=1) + + # Zero out position embeddings for padding patches (matching HF) + padding_positions = (pixel_position_ids == -1).all(dim=-1) + position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings) return hidden_states + position_embeddings @@ -658,85 +701,79 @@ class Gemma4VisionEncoder(nn.Module): self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, operations=operations) self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, operations=operations) - def forward(self, pixel_values): + def forward(self, pixel_values, max_soft_tokens=None): """ - pixel_values: [B, C, H, W] in [0, 1] range - Returns: [B, output_tokens, hidden_size] projected vision tokens + pixel_values: [B, C, H, W] in [0,1] range + max_soft_tokens: if provided, pad to max_soft_tokens * k² total patches """ - batch_size, channels, height, width = pixel_values.shape - patches_h = height // self.patch_size - patches_w = width // self.patch_size + batch_size, _, height, width = pixel_values.shape + ps = self.patch_size + k = self.pooling_kernel_size + patches_h, patches_w = height // ps, width // ps num_patches = patches_h * patches_w + output_length = max_soft_tokens if max_soft_tokens is not None else num_patches // (k * k) + n_padding = output_length * k * k - num_patches - # Generate position IDs: grid of (col, row) per patch - # HF processor uses (x=col, y=row) convention for position_ids - rows = torch.arange(patches_h, device=pixel_values.device) - cols = torch.arange(patches_w, device=pixel_values.device) - grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij') - pixel_position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) # [num_patches, 2] - pixel_position_ids = pixel_position_ids.unsqueeze(0).expand(batch_size, -1, -1) # [B, num_patches, 2] + # Patchify and build position grid + patches = pixel_values.reshape(batch_size, -1, patches_h, ps, patches_w, ps) + patches = 
patches.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, num_patches, -1) + grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij') + position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1) - # Patch embedding + position embedding - x = self.patch_embedder(pixel_values, pixel_position_ids) + # Append zero-pixel padding with (-1,-1) positions (matching HF) + if n_padding > 0: + patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1) + position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1) - # Compute 2D RoPE cos/sin for attention - cos_sin = _compute_vision_2d_rope(self.head_dim, pixel_position_ids, device=pixel_values.device) + padding = (position_ids == -1).all(dim=-1) - optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True) + # Embed, encode, pool + x = self.patch_embedder(patches, position_ids) + cos_sin = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) + cos_sin = tuple(t.to(x.dtype) for t in cos_sin) + mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None for layer in self.encoder.layers: - x = layer(x, cos_sin=cos_sin, optimized_attention=optimized_attention) + x = layer(x, cos_sin=cos_sin, attention_mask=mask) - # Pooling: position-aware average pooling matching HF's Gemma4VisionPooler - k = self.pooling_kernel_size # 3 - k_squared = k * k - output_length = num_patches // k_squared - if num_patches != output_length and output_length > 0: - # Assign each patch to a kernel block based on its (col, row) position - kernel_col = pixel_position_ids[:, :, 0] // k # col // k - kernel_row = pixel_position_ids[:, :, 1] // k # row // k - stride = patches_w // k # matches HF's (max_x + 1) // k - 
kernel_idxs = kernel_col + stride * kernel_row # [B, num_patches] + if n_padding > 0: + x = x.masked_fill(padding.unsqueeze(-1), 0.0) - # One-hot assignment matrix and weighted average - weights = torch.nn.functional.one_hot(kernel_idxs.long(), output_length).float() / k_squared - x = (weights.transpose(1, 2) @ x.float()).to(x.dtype) # [B, output_length, hidden] + # Average pool by spatial position + clamped = position_ids.clamp(min=0) + max_x = clamped[:, :, 0].max(dim=-1, keepdim=True)[0] + 1 + ki = torch.div(clamped, k, rounding_mode="floor") + ki = ki[:, :, 0] + (max_x // k) * ki[:, :, 1] + weights = torch.nn.functional.one_hot(ki.long(), output_length).float() / (k * k) + x = (weights.transpose(1, 2) @ x.float()).to(x.dtype) - # Scale by sqrt(hidden_size) like HF pooler - x = x * self.root_hidden_size - return x + # Strip empty output tokens + valid_out = ~((weights == 0).all(dim=1)) + if valid_out.any() and not valid_out.all(): + x = x[:, valid_out[0]] if batch_size > 1 else x[valid_out].unsqueeze(0) + + return x * self.root_hidden_size -class Gemma4MultiModalProjector(nn.Module): - def __init__(self, config, dtype=None, device=None, operations=None): +class Gemma4RMSNormProjector(nn.Module): + """Shared projector: parameterless RMSNorm → linear. 
Used for both vision and audio.""" + def __init__(self, in_dim, out_dim, dtype=None, device=None, operations=None): super().__init__() - vision_hidden_size = config.vision_config["hidden_size"] - text_hidden_size = config.hidden_size - self.embedding_projection = operations.Linear(vision_hidden_size, text_hidden_size, bias=False, device=device, dtype=dtype) + self.embedding_projection = operations.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) - def forward(self, vision_outputs): - return self.embedding_projection(_parameterless_rms_norm(vision_outputs)) + def forward(self, x): + return self.embedding_projection(_parameterless_rms_norm(x)) -# --- Audio Encoder --- -# Conformer-style architecture matching HF weight structure after conversion: -# audio_model.subsample_conv_projection.layer0.conv.weight [128, 1, 3, 3] -# audio_model.subsample_conv_projection.layer0.norm.weight [128] -# audio_model.subsample_conv_projection.layer1.conv.weight [32, 128, 3, 3] -# audio_model.subsample_conv_projection.layer1.norm.weight [32] -# audio_model.subsample_conv_projection.input_proj_linear.weight [1024, 1024] -# audio_model.layers.X.feed_forward1.{pre,post}_layer_norm.weight [1024] -# audio_model.layers.X.feed_forward1.ffw_layer_{1,2}.weight [4096/1024, 1024/4096] -# audio_model.layers.X.self_attn.{q,k,v}_proj.weight [1024, 1024] -# audio_model.layers.X.self_attn.post.weight [1024, 1024] -# audio_model.layers.X.self_attn.per_dim_scale [128] -# audio_model.layers.X.self_attn.relative_k_proj.weight [1024, 1024] -# audio_model.layers.X.lconv1d.{linear_start,linear_end}.weight, depthwise_conv1d.weight -# audio_model.layers.X.feed_forward2.* (same as feed_forward1) -# audio_model.output_proj.{weight, bias} +class Gemma4MultiModalProjector(Gemma4RMSNormProjector): + def __init__(self, config, dtype=None, device=None, operations=None): + super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, operations=operations) + + +# 
Audio Encoder class Gemma4AudioConvSubsampler(nn.Module): - """2D convolution subsampling for audio features, matching HF Gemma4AudioSubSampleConvProjection.""" + """2D convolution subsampling for audio features""" def __init__(self, config, device=None, dtype=None, operations=None): super().__init__() eps = config.get("rms_norm_eps", 1e-6) @@ -751,18 +788,22 @@ class Gemma4AudioConvSubsampler(nn.Module): # proj_input_dim = (128 // 4) * 32 = 1024 self.input_proj_linear = operations.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) - def forward(self, x): - # x: [batch, time, features] - x = x.unsqueeze(1) # [batch, 1, time, features] - x = self.layer0['conv'](x.to(self.layer0['conv'].weight.dtype)) - x = torch.relu(self.layer0['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) - - x = self.layer1['conv'](x) - x = torch.relu(self.layer1['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + def _conv_layer(self, x, layer, mask): + if mask is not None: + x = x * mask[:, None, :, None].to(x.device) + x = layer['conv'](x.to(layer['conv'].weight.dtype)) + x = torch.relu(layer['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + if mask is not None: + mask = mask[:, ::2] + return x, mask + def forward(self, x, mask=None): + x = x.unsqueeze(1) + x, mask = self._conv_layer(x, self.layer0, mask) + x, mask = self._conv_layer(x, self.layer1, mask) batch_size, _, seq_len, _ = x.shape x = x.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1) - return self.input_proj_linear(x) + return self.input_proj_linear(x), mask class Gemma4AudioFeedForward(nn.Module): @@ -771,10 +812,10 @@ class Gemma4AudioFeedForward(nn.Module): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config.get("intermediate_size", hidden_size * 4) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.ffw_layer_1 = 
_make_clipped_linear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.ffw_layer_2 = _make_clipped_linear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) - self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) + self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) self.post_layer_scale = config.get("residual_weight", 0.5) self.gradient_clipping = config.get("gradient_clipping", 1e10) @@ -796,21 +837,18 @@ class Gemma4AudioRelPositionalEncoding(nn.Module): def __init__(self, config, device=None, dtype=None): super().__init__() hidden_size = config["hidden_size"] - chunk_size = config.get("attention_chunk_size", 12) context_left = config.get("attention_context_left", 13) context_right = config.get("attention_context_right", 0) - self.context_size = chunk_size + context_left - 1 + context_right + self.chunk_size = config.get("attention_chunk_size", 12) + self.context_size = self.chunk_size + context_left - 1 + context_right - import math num_timescales = hidden_size // 2 log_inc = math.log(10000.0) / max(num_timescales - 1, 1) - inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).unsqueeze(0).unsqueeze(0) + inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).to(dtype=dtype).unsqueeze(0).unsqueeze(0) self.register_buffer("inv_timescales", inv_timescales, persistent=False) - @torch.no_grad() def forward(self, hidden_states): - chunk_size = 12 # matches HF hardcoded value - positions = 
torch.arange(chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1) + positions = torch.arange(self.chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1) scaled = positions * self.inv_timescales.to(device=hidden_states.device) return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).to(dtype=hidden_states.dtype) @@ -819,7 +857,6 @@ class Gemma4AudioAttention(nn.Module): """Chunked block attention with relative position bias and softcap.""" def __init__(self, config, device=None, dtype=None, operations=None): super().__init__() - import math self.hidden_size = config["hidden_size"] self.num_heads = config["num_attention_heads"] self.head_dim = self.hidden_size // self.num_heads @@ -830,12 +867,12 @@ class Gemma4AudioAttention(nn.Module): self.q_scale = (self.head_dim ** -0.5) / math.log(2) self.k_scale = math.log(1 + math.e) / math.log(2) - self.softcap = config.get("attention_logit_cap", 50.0) + self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False) - self.q_proj = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.k_proj = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.v_proj = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.post = _make_clipped_linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, 
operations=operations) self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype)) self.relative_k_proj = operations.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) @@ -844,7 +881,7 @@ class Gemma4AudioAttention(nn.Module): num_blocks = (S + self.chunk_size - 1) // self.chunk_size pad = num_blocks * self.chunk_size - S x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad)) - return x.reshape(B, num_blocks, self.chunk_size, H, D) + return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous() def _extract_block_context(self, x): B, S, H, D = x.shape @@ -860,14 +897,32 @@ class Gemma4AudioAttention(nn.Module): x = x[..., :BS * CS] return x.view(B, H, NB, BS, CS) - def forward(self, x, position_embeddings=None): + def _build_blocked_mask(self, seq_len, num_blocks, device, audio_mask=None): + """Build 5D boolean blocked attention mask (True=attend, False=mask)""" + q = torch.arange(seq_len, device=device) + dist = q[:, None] - q[None, :] + mask = (dist >= 0) & (dist < self.max_past_horizon) + if self.max_future_horizon > 0: + mask = mask | ((dist < 0) & ((-dist) < self.max_future_horizon)) + if audio_mask is not None: + mask = mask & audio_mask[0, None, :].bool() + m = mask[None, None] + # Reshape to blocked 5D matching reference's _convert_4d_mask_to_blocked_5d + p = num_blocks * self.chunk_size - seq_len + m = torch.nn.functional.pad(m, (0, p, 0, p), value=False) + m = m.reshape(1, 1, num_blocks, self.chunk_size, -1) + m = torch.nn.functional.pad(m, (self.max_past_horizon, self.max_future_horizon), value=False) + idx = (torch.arange(num_blocks, device=device) * self.chunk_size)[:, None] + torch.arange(self.context_size, device=device)[None, :] + return m.gather(-1, idx[None, None, :, None, :].expand(1, 1, -1, self.chunk_size, -1)) + + def forward(self, x, position_embeddings=None, attn_mask=None): B, S, _ = x.shape q = self.q_proj(x).float().view(B, S, self.num_heads, self.head_dim) k = 
self.k_proj(x).float().view(B, S, self.num_heads, self.head_dim) v = self.v_proj(x).float().view(B, S, self.num_heads, self.head_dim) - q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale.float()) + q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale) k = k * self.k_scale q_blocks = self._convert_to_block(q) @@ -888,6 +943,11 @@ class Gemma4AudioAttention(nn.Module): attn_weights = matrix_ac + matrix_bd attn_weights = torch.tanh(attn_weights / self.softcap) * self.softcap + # Mask out invalid positions in chunk context (matching reference's masked_fill approach) + if attn_mask is None: + attn_mask = self._build_blocked_mask(S, num_blocks, x.device) + attn_weights = attn_weights.masked_fill(attn_mask.logical_not(), -1e9) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype) out = attn_weights @ v_context.permute(0, 3, 1, 2, 4) out = out.permute(0, 2, 3, 1, 4).reshape(B, num_blocks * self.chunk_size, -1) @@ -902,19 +962,19 @@ class Gemma4AudioLConv1d(nn.Module): hidden_size = config["hidden_size"] conv_kernel_size = config.get("conv_kernel_size", 5) self.gradient_clipping = config.get("gradient_clipping", 1e10) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.linear_start = _make_clipped_linear(hidden_size, hidden_size * 2, device=device, dtype=dtype, operations=operations) - # Causal conv: left-pad only (no right padding) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, operations=operations) + # Causal conv: left-pad only self.depthwise_conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by 
kernel-1 - self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.linear_end = _make_clipped_linear(hidden_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, operations=operations) def forward(self, x): residual = x x = self.pre_layer_norm(x) x = self.linear_start(x) - x = torch.nn.functional.glu(x, dim=-1) # standard GLU, not gelu-gated + x = torch.nn.functional.glu(x, dim=-1) x = x.transpose(1, 2) x = torch.nn.functional.pad(x, (self.conv_left_pad, 0)) x = self.depthwise_conv1d(x).transpose(1, 2) @@ -930,24 +990,25 @@ class Gemma4AudioLayer(nn.Module): """Conformer block: FFN1 -> Attention -> LConv -> FFN2.""" def __init__(self, config, device=None, dtype=None, operations=None): super().__init__() - hidden_size = config["hidden_size"] self.gradient_clipping = config.get("gradient_clipping", 1e10) self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, operations=operations) - self.norm_pre_attn = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) - self.norm_post_attn = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + hidden_size = config["hidden_size"] + self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) + self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, operations=operations) self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) - self.norm_out = 
RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype) + self.norm_out = RMSNorm(hidden_size, **norm_kwargs) - def forward(self, x, position_embeddings=None): + def forward(self, x, position_embeddings=None, attn_mask=None): gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) x = self.feed_forward1(x) residual = x x = torch.clamp(x, -gc, gc) x = self.norm_pre_attn(x) - x = self.self_attn(x, position_embeddings=position_embeddings) + x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask) x = torch.clamp(x, -gc, gc) x = self.norm_post_attn(x) x = x + residual @@ -976,90 +1037,130 @@ class Gemma4AudioEncoder(nn.Module): self.output_proj = operations.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) - def forward(self, audio_features): - x = self.subsample_conv_projection(audio_features) + def forward(self, audio_features, audio_mask=None): + x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask) position_embeddings = self.rel_pos_enc(x) + # Build blocked attention mask once for all layers + attn_mask = self.layers[0].self_attn._build_blocked_mask( + x.shape[1], (x.shape[1] + self.layers[0].self_attn.chunk_size - 1) // self.layers[0].self_attn.chunk_size, + x.device, audio_mask=audio_mask) + for layer in self.layers: - x = layer(x, position_embeddings=position_embeddings) + x = layer(x, position_embeddings=position_embeddings, attn_mask=attn_mask) x = self.output_proj(x) return x -class Gemma4AudioProjector(nn.Module): +class Gemma4AudioProjector(Gemma4RMSNormProjector): def __init__(self, config, dtype=None, device=None, operations=None): - super().__init__() - audio_output_dim = config.get("audio_output_proj_dims", 1536) - text_hidden_size = config.get("text_hidden_size", 2560) - self.embedding_projection = operations.Linear(audio_output_dim, text_hidden_size, bias=False, device=device, dtype=dtype) - - def forward(self, audio_outputs): - 
return self.embedding_projection(_parameterless_rms_norm(audio_outputs)) + super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, operations=operations) -# --- Tokenizer & Wrappers --- +# Tokenizer and Wrappers class Gemma4_Tokenizer(): def state_dict(self): return {} def _extract_mel_spectrogram(self, waveform, sample_rate): - """Extract log mel spectrogram using HF's Gemma4AudioFeatureExtractor.""" - import torchaudio - from transformers.models.gemma4.feature_extraction_gemma4 import Gemma4AudioFeatureExtractor - if sample_rate != 16000: - waveform = torchaudio.functional.resample(waveform, sample_rate, 16000) + """Extract 128-bin log mel spectrogram. + Uses numpy for FFT/matmul/log to produce bit-identical results with reference code. + """ + # Mix to mono first, then resample to 16kHz if waveform.dim() > 1 and waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) if waveform.dim() == 1: waveform = waveform.unsqueeze(0) - # Convert to numpy for HF feature extractor - audio_np = waveform.squeeze(0).numpy() - fe = Gemma4AudioFeatureExtractor() - result = fe([audio_np], return_tensors='pt') - return result['input_features'][0] # [T, 128] + audio = waveform.squeeze(0).float().numpy() + if sample_rate != 16000: + # import librosa + # audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000) + # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (still not full match) + from scipy.signal import resample_poly, firwin + from math import gcd + g = gcd(sample_rate, 16000) + up, down = 16000 // g, sample_rate // g + L = max(up, down) + h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5)) + audio = resample_poly(audio, up, down, window=h).astype(np.float32) + n = len(audio) - def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, llama_template=None, skip_template=True, thinking=False, 
**kwargs): - if thinking: - self.llama_template = "<|turn>system\n<|think|>\n<|turn>user\n{}\n<|turn>model\n" - self.llama_template_images = "<|turn>system\n<|think|>\n<|turn>user\n\n\n<|image><|image|>\n\n{}\n<|turn>model\n" - else: - self.llama_template = "<|turn>user\n{}\n<|turn>model\n" - self.llama_template_images = "<|turn>user\n\n\n<|image><|image|>\n\n{}\n<|turn>model\n" + # Pad to multiple of 128, build sample-level mask + if n % 128 != 0: + audio = np.pad(audio, (0, 128 - n % 128)) + mask_raw = np.ones(len(audio), dtype=np.float32) + mask_raw[n:] = 0.0 + + # Semicausal padding: 160 zeros prepended + audio = np.pad(audio, (160, 0)) + mask_raw = np.pad(mask_raw, (160, 0)) + + # Extract 321-sample frames via stride tricks, drop last → 320 + nf = (len(audio) - 321) // 160 + 1 + strides = (audio.strides[0] * 160, audio.strides[0]) + frames = np.lib.stride_tricks.as_strided(audio, (nf, 321), strides)[..., :-1].copy() + + # Periodic Hann window, FFT magnitude, mel filterbank, log + window = (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(320) / 320)).astype(np.float32) + magnitude = np.abs(np.fft.rfft(frames * window, n=512, axis=-1)) + mel_fb = self._build_mel_filterbank() + log_mel = np.log(np.matmul(magnitude, mel_fb) + np.float64(0.001)).astype(np.float32) + + # Frame mask: valid when last sample in window is real audio + mask = mask_raw[np.arange(nf) * 160 + 320].astype(bool) + log_mel = log_mel * mask[:, None] + return torch.from_numpy(log_mel), torch.from_numpy(mask) # [T, 128], [T] + + @staticmethod + def _build_mel_filterbank(): + """Build 128-bin HTK mel filterbank [257, 128] for 512-pt FFT at 16kHz.""" + mel_freqs = np.linspace(0.0, 2595.0 * np.log10(1.0 + 8000.0 / 700.0), 130) + filter_freqs = 700.0 * (10.0 ** (mel_freqs / 2595.0) - 1.0) + fft_freqs = np.linspace(0, 16000 // 2, 257) + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = 
slopes[:, 2:] / filter_diff[1:] + return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes)) + + def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, video=None, llama_template=None, skip_template=True, thinking=False, **kwargs): + self.thinking = thinking # Process audio audio_features = [] if audio is not None: waveform = audio["waveform"].squeeze(0) if isinstance(audio, dict) else audio sample_rate = audio.get("sample_rate", 16000) if isinstance(audio, dict) else 16000 - mel = self._extract_mel_spectrogram(waveform, sample_rate) - audio_features = [mel.unsqueeze(0)] # [1, T, 128] + mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) + audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) - if image is None: - images = [] - else: - samples = image.movedim(-1, 1) # [B, C, H, W] + # Process image/video frames + is_video = video is not None + source = video if is_video else image + images = [] + if source is not None: + samples = source.movedim(-1, 1) # [B, C, H, W] h, w = samples.shape[2], samples.shape[3] - # Aspect-ratio-preserving resize matching HF Gemma4ImageProcessor patch_size = 16 pooling_k = 3 - max_patches = 280 * pooling_k * pooling_k # 2520 + max_soft_tokens = 70 if is_video else 280 # video uses smaller token budget per frame + max_patches = max_soft_tokens * pooling_k * pooling_k target_px = max_patches * patch_size * patch_size factor = (target_px / (h * w)) ** 0.5 - side_mult = pooling_k * patch_size # 48 + side_mult = pooling_k * patch_size target_h = max(int(factor * h // side_mult) * side_mult, side_mult) target_w = max(int(factor * w // side_mult) * side_mult, side_mult) - # Resize via PIL to match HF processor (operates on uint8, not float tensors) - from PIL import Image - import numpy as np - img_uint8 = (samples[0].permute(1, 2, 0).clamp(0, 1) * 255).byte().cpu().numpy() - pil_img = Image.fromarray(img_uint8).resize((target_w, target_h), Image.BICUBIC) - 
s = torch.from_numpy(np.array(pil_img).astype(np.float32) / 255.0) - s = s.permute(2, 0, 1).unsqueeze(0).to(samples.device) - s = 2 * (s - 0.5) # normalize [0,1] -> [-1,1] - images = [s.movedim(1, -1)[:, :, :, :3]] + import torchvision.transforms.functional as TVF + for i in range(samples.shape[0]): + # recaling to match reference code + s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 + if target_h != h or target_w != w: + s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True) + s = s.float() * (1.0 / 255.0) + images.append({"pixels": s.unsqueeze(0).movedim(1, -1)[:, :, :, :3], "max_soft_tokens": max_soft_tokens}) if text.startswith('<|turn>'): skip_template = True @@ -1067,82 +1168,120 @@ class Gemma4_Tokenizer(): if skip_template: llama_text = text else: - if llama_template is None: - if len(images) > 0: - llama_text = self.llama_template_images.format(text) - elif len(audio_features) > 0: - llama_text = f"<|turn>user\n\n<|audio><|audio|>{text}\n<|turn>model\n" - else: - llama_text = self.llama_template.format(text) - else: + if llama_template is not None: llama_text = llama_template.format(text) + else: + # Build template from modalities present + system = "<|turn>system\n<|think|>\n" if self.thinking else "" + media = "" + if len(images) > 0: + if is_video: + fps = kwargs.get("fps", 24) + media += "\n\n" + for i in range(len(images)): + seconds = i / fps + ts = f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" + sep = "" if i == 0 else " " + media += f"{sep}{ts} <|image><|video|>" + media += "\n\n" + else: + media += "\n\n" + for i in range(len(images)): + if i > 0: + media += "\n\n\n\n" + media += "<|image><|image|>" + media += "\n\n" + if len(audio_features) > 0: + # Compute audio token count (always at 16kHz) + num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1] + _fl = 320 # int(round(16000 * 20.0 / 1000.0)) + _hl = 160 # 
int(round(16000 * 10.0 / 1000.0)) + _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1 + _t = _nmel + for _ in range(2): + _t = (_t + 2 - 3) // 2 + 1 + n_audio_tokens = min(_t, 750) + media += "<|audio>" + "<|audio|>" * n_audio_tokens + "" + llama_text = f"{system}<|turn>user\n{media}{text}\n<|turn>model\n" text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + def _replace_placeholders(token_list, token_id, embeds): + """Replace first placeholder with embed dict, remove remaining consecutive ones.""" + embed_idx = 0 + i = 0 + while i < len(token_list): + if token_list[i][0] == token_id and embed_idx < len(embeds): + token_list[i] = (embeds[embed_idx],) + token_list[i][1:] + embed_idx += 1 + i += 1 + while i < len(token_list) and token_list[i][0] == token_id: + token_list.pop(i) + else: + i += 1 + if len(images) > 0: - embed_count = 0 - for r in text_tokens: - for i, token in enumerate(r): - if token[0] == 258880 and embed_count < len(images): - r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:] - embed_count += 1 + if is_video: + # Video: batch all frames into one embed dict, each placeholder gets its frame's tokens + all_pixels = torch.cat([img["pixels"] for img in images], dim=0) # [N, H, W, C] + img_embeds = [{"type": "video", "data": all_pixels, "max_soft_tokens": images[0]["max_soft_tokens"], "frame_idx": i} for i in range(len(images))] + for r in text_tokens: + _replace_placeholders(r, 258884, img_embeds) + else: + img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] + for r in text_tokens: + _replace_placeholders(r, 258880, img_embeds) if len(audio_features) > 0: - embed_count = 0 + aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features] for r in text_tokens: - for i, token in enumerate(r): - if token[0] == 258881 and embed_count < len(audio_features): - r[i] = ({"type": "audio", "data": 
audio_features[embed_count]},) + token[1:] - embed_count += 1 + _replace_placeholders(r, 258881, aud_embeds) return text_tokens -class Gemma4HFTokenizer: - """Wrapper to load GemmaTokenizer from tokenizer.json bytes embedded in safetensors.""" +class _Gemma4Tokenizer: + """Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)""" def __init__(self, tokenizer_json_bytes=None, **kwargs): - import tempfile, os, json - from transformers import AutoTokenizer - self.temp_dir = tempfile.mkdtemp() + from tokenizers import Tokenizer if isinstance(tokenizer_json_bytes, torch.Tensor): tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) - with open(os.path.join(self.temp_dir, "tokenizer.json"), "wb") as f: - f.write(tokenizer_json_bytes) - # Minimal tokenizer_config.json - with open(os.path.join(self.temp_dir, "tokenizer_config.json"), "w") as f: - json.dump({"tokenizer_class": "GemmaTokenizer", "add_bos_token": True, "add_eos_token": False}, f) - self.tokenizer = AutoTokenizer.from_pretrained(self.temp_dir) + self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) @classmethod def from_pretrained(cls, tokenizer_data, **kwargs): return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) def __call__(self, text): - ids = self.tokenizer.encode(text, add_special_tokens=False) - return {"input_ids": ids} + return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids} def get_vocab(self): return self.tokenizer.get_vocab() def convert_tokens_to_ids(self, tokens): - return self.tokenizer.convert_tokens_to_ids(tokens) + return [self.tokenizer.token_to_id(t) for t in tokens] def decode(self, ids, **kwargs): - return self.tokenizer.decode(ids, **kwargs) + return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False)) -class Gemma4_E4BTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): +# Tokenizer +class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): + embedding_size = 2560 def 
__init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_json = tokenizer_data.get("tokenizer_json", None) - super().__init__(tokenizer_json, pad_with_end=False, embedding_size=2560, embedding_key='gemma4_e4b', tokenizer_class=Gemma4HFTokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): + tokenizer_class = Gemma4SDTokenizer def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4_e4b", tokenizer=Gemma4_E4BTokenizer) + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class) -class Gemma4_E4BModel(sd1_clip.SDClipModel): +# Model wrappers +class Gemma4Model(sd1_clip.SDClipModel): + model_class = None def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) if llama_quantization_metadata is not None: @@ -1150,16 +1289,10 @@ class Gemma4_E4BModel(sd1_clip.SDClipModel): model_options["quantization_metadata"] = llama_quantization_metadata self.dtypes = set() self.dtypes.add(dtype) - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma4_E4B, enable_attention_masks=attention_mask, 
return_attention_masks=attention_mask, model_options=model_options) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=self.model_class, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) def process_tokens(self, tokens, device): - embeds, _, _, embeds_info = super().process_tokens(tokens, device) - scale = self.transformer.model.config.hidden_size ** 0.5 - # Undo text embedding scaling for multimodal tokens (vision/audio) - for info in embeds_info: - start_idx = info["index"] - end_idx = start_idx + info["size"] - embeds[:, start_idx:end_idx, :] /= scale + embeds, _, _, _ = super().process_tokens(tokens, device) return embeds def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0): @@ -1167,24 +1300,26 @@ class Gemma4_E4BModel(sd1_clip.SDClipModel): tokens = next(iter(tokens.values())) tokens_only = [[t[0] for t in b] for b in tokens] embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device) - # Build input_ids matching embeds length for per-layer embeddings - # HF uses pad_token_id (0) at multimodal positions, not the placeholder ID - base_ids = [t if isinstance(t, int) else 0 for t in tokens_only[0]] - # Expand: each multimodal position was 1 token, now occupies `size` positions - initial_token_ids = [base_ids] - for info in sorted(embeds_info, key=lambda i: i["index"], reverse=True): - idx, size = info["index"], info["size"] - initial_token_ids[0] = initial_token_ids[0][:idx] + [0] * size + initial_token_ids[0][idx + 1:] - scale = self.transformer.model.config.hidden_size ** 0.5 - for info in embeds_info: - start_idx = info["index"] - end_idx = start_idx + info["size"] - embeds[:, start_idx:end_idx, :] /= scale + seq_len = embeds.shape[1] + ids = 
[0] * seq_len + expanded_idx = 0 + embed_map = {info["index"]: info["size"] for info in embeds_info} + for t in tokens_only[0]: + if expanded_idx in embed_map: + expanded_idx += embed_map[expanded_idx] + elif isinstance(t, int): + if expanded_idx < seq_len: + ids[expanded_idx] = t + expanded_idx += 1 + else: + expanded_idx += 1 + initial_token_ids = [ids] input_ids = torch.tensor(initial_token_ids, device=self.execution_device) - return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) -def gemma4_te(dtype_llama=None, llama_quantization_metadata=None): +def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=Gemma4_E4B): + clip_model = type('Gemma4Model_', (Gemma4Model,), {'model_class': model_class}) class Gemma4TEModel_(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: @@ -1192,5 +1327,24 @@ def gemma4_te(dtype_llama=None, llama_quantization_metadata=None): model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(device=device, dtype=dtype, name="gemma4_e4b", clip_model=Gemma4_E4BModel, model_options=model_options) + super().__init__(device=device, dtype=dtype, name="gemma4", clip_model=clip_model, model_options=model_options) return Gemma4TEModel_ + + +# Variants: config + model_class + embedding_size +class Gemma4_E2B(Gemma4AudioMixin, Gemma4Base): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + 
self._init_model(Gemma4_E2B_Config(**config_dict), dtype, device, operations) + self._init_audio(self.model.config, dtype, device, operations) + +class Gemma4_31B(Gemma4Base): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self._init_model(Gemma4_31B_Config(**config_dict), dtype, device, operations) + +class Gemma4_E2BTokenizerWrapper(Gemma4Tokenizer): + tokenizer_class = type('T', (Gemma4SDTokenizer,), {'embedding_size': 1536}) + +class Gemma4_31BTokenizerWrapper(Gemma4Tokenizer): + tokenizer_class = type('T', (Gemma4SDTokenizer,), {'embedding_size': 5376}) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ad0965161..d2aa59090 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -358,18 +358,19 @@ class Gemma3_12B_Config: stop_tokens = [1, 106] class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): + def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None, fused=True): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) self.add = add + self.fused = fused def forward(self, x: torch.Tensor): w = self.weight if self.add: w = w + 1.0 - return comfy.ldm.common_dit.rms_norm(x, w, self.eps) + return comfy.ldm.common_dit.rms_norm(x, w, self.eps, fused=self.fused) @@ -497,7 +498,7 @@ class Attention(nn.Module): else: present_key_value = (xk, xv, index + num_tokens) - if sliding_window is not None and xk.shape[2] > sliding_window: + if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1: xk = xk[:, :, -sliding_window:] xv = xv[:, :, -sliding_window:] attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None @@ -509,12 +510,12 @@ class Attention(nn.Module): return self.o_proj(output), present_key_value class MLP(nn.Module): - def __init__(self, config: Llama2Config, device=None, 
dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None): super().__init__() - ops = ops or nn - self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) + intermediate_size = intermediate_size or config.intermediate_size + self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) if config.mlp_activation == "silu": self.activation = torch.nn.functional.silu elif config.mlp_activation == "gelu_pytorch_tanh": @@ -623,6 +624,10 @@ class TransformerBlockGemma2(nn.Module): return x, present_key_value +def _gemma_embed_scale_hook(module, input, output): + return (output.to(module._embed_scale.dtype) * module._embed_scale).to(output.dtype) + + class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() @@ -637,10 +642,10 @@ class Llama2_(nn.Module): ) if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 - self.normalize_in = True + self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False) + self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook) else: transformer = TransformerBlock - self.normalize_in = False self.layers = nn.ModuleList([ transformer(config, index=i, device=device, dtype=dtype, ops=ops) @@ -672,9 +677,6 @@ 
class Llama2_(nn.Module): else: x = self.embed_tokens(x, out_dtype=dtype) - if self.normalize_in: - x *= self.config.hidden_size ** 0.5 - seq_len = x.shape[1] past_len = 0 if past_key_values is not None and len(past_key_values) > 0: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5aee1f4c0..bc5cbae28 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty): tokens_only = [[t[0] for t in b] for b in tokens] - embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is class DualLinearProjection(torch.nn.Module): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 01ebdfabe..b1f1dbb9f 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) def process_tokens(self, tokens, device): - embeds, _, _, embeds_info = super().process_tokens(tokens, device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = 
super().process_tokens(tokens, device) return embeds class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index ce9b07464..d8ed9cd32 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_): nn.Module.__init__(self) self.config = config self.vocab_size = config.vocab_size - self.normalize_in = False - self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops) diff --git a/comfy/utils.py b/comfy/utils.py index 78c491b98..7b7faad3a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res -def normalize_image_embeddings(embeds, embeds_info, scale_factor): - """Normalize image embeddings to match text embedding scale""" - for info in embeds_info: - if info.get("type") == "image": - start_idx = info["index"] - end_idx = start_idx + info["size"] - embeds[:, start_idx:end_idx, :] /= scale_factor diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index d2fa48223..4235fd310 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -32,6 +32,7 @@ class TextGenerate(io.ComfyNode): io.Clip.Input("clip"), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.Image.Input("image", optional=True), + io.Image.Input("video", optional=True, tooltip="Video frames as image batch (1 FPS recommended)."), io.Audio.Input("audio", optional=True), io.Int.Input("max_length", default=256, min=1, max=2048), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), @@ -43,9 +44,9 @@ class TextGenerate(io.ComfyNode): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, audio=None, 
thinking=False) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput: - tokens = clip.tokenize(prompt, image=image, audio=audio, skip_template=False, min_length=1, thinking=thinking) + tokens = clip.tokenize(prompt, image=image, video=video, audio=audio, skip_template=False, min_length=1, thinking=thinking) # Get sampling parameters from dynamic combo do_sample = sampling_mode.get("sampling_mode") == "on" From 05eaceafa195d443610501c5f946a9109ed6acf8 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:37:29 +0300 Subject: [PATCH 03/18] Cleanup, video fixes --- comfy/text_encoders/gemma4.py | 270 ++++++++++++++-------------------- 1 file changed, 114 insertions(+), 156 deletions(-) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index ee4c672f2..9fac8c66a 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -7,12 +7,13 @@ import math from comfy import sd1_clip import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.rmsnorm import rms_norm from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook -GEMMA4_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "model_type": "gemma4_vision", "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} -GEMMA4_VISION_31B_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "gemma4_vision", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, 
"pooling_kernel_size": 3} -GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5, "gradient_clipping": 1e10, "hidden_act": "silu"} +GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} @dataclass class Gemma4Config: @@ -74,6 +75,7 @@ class Gemma4_31B_Config(Gemma4Config): vision_config = GEMMA4_VISION_31B_CONFIG +# unfused RoPE as addcmul_ RoPE diverges from reference code def _apply_rotary_pos_emb(x, freqs_cis): cos, sin = freqs_cis[0], freqs_cis[1] half = x.shape[-1] // 2 @@ -82,7 +84,6 @@ def _apply_rotary_pos_emb(x, freqs_cis): out[..., half:] += x[..., :half] * sin[..., half:] return out - def _apply_rope_gemma(xq, xk, freqs_cis): return _apply_rotary_pos_emb(xq, freqs_cis), _apply_rotary_pos_emb(xk, freqs_cis) @@ -117,7 +118,6 @@ class Gemma4Attention(nn.Module): past_key_value=None, sliding_window=None, shared_kv=None, - **kwargs, ): batch_size, seq_length, _ = hidden_states.shape @@ -137,7 
+137,7 @@ class Gemma4Attention(nn.Module): xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) if self.k_norm is not None: xk = self.k_norm(xk) - xv = _parameterless_rms_norm(xv) + xv = rms_norm(xv, fused=False) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) xq, xk = _apply_rope_gemma(xq, xk, freqs_cis=freqs_cis) @@ -298,22 +298,17 @@ class Gemma4Transformer(nn.Module): return kv[2] return 0 - def _freqs_from_inv(self, inv_freq, position_ids, dtype=None): + def _freqs_from_inv(self, inv_freq, position_ids, device, dtype): """Compute cos/sin from stored inv_freq""" - inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(position_ids.device) + inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device) pos_exp = position_ids[:, None, :].float() freqs = (inv_exp @ pos_exp).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().unsqueeze(1) - sin = emb.sin().unsqueeze(1) - result = (cos, sin) - if dtype is not None: - result = tuple(t.to(dtype) for t in result) - return result + return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype) def compute_freqs_cis(self, position_ids, device, dtype=None): - global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, dtype) - sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, dtype) + global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype) + sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype) return [global_freqs, sliding_freqs] def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, @@ -358,7 +353,7 @@ class Gemma4Transformer(nn.Module): else: per_layer_inputs = per_layer_proj - # KV sharing: only last sliding (22) and last global (23) layers store KV for sharing + # KV sharing: later layers reuse KV from the last non-shared sliding/global layer 
num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0) first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers shared_sliding_kv = None # KV from last non-shared sliding layer @@ -407,8 +402,8 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): self.num_layers = config.num_hidden_layers self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype - self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype, device, operations) - self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype, device, operations) + self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations) + self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations) def logits(self, x): logits = super().logits(x) @@ -425,39 +420,26 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): def preprocess_embed(self, embed, device): if embed["type"] == "image": - image = embed["data"].movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] + image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] max_soft_tokens = embed.get("max_soft_tokens", None) vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) return self.multi_modal_projector(vision_out), None - if embed["type"] == "video": - frame_idx = embed.get("frame_idx", 0) - if not hasattr(self, '_video_cache') or self._video_cache is None: - # First frame: process all frames as a batch - frames = embed["data"].movedim(-1, 1) # [N, H, W, C] -> [N, C, H, W] - max_soft_tokens = embed.get("max_soft_tokens", None) - vision_out = self.vision_model(frames.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) - projected = self.multi_modal_projector(vision_out) # [N, tokens_per_frame, hidden] - self._video_cache = projected - result = 
self._video_cache[frame_idx:frame_idx+1] # [1, tokens_per_frame, hidden] - if frame_idx == self._video_cache.shape[0] - 1: - self._video_cache = None # clear after last frame - return result, None return None, None class Gemma4AudioMixin: """Adds audio support to a Gemma4 model.""" def _init_audio(self, config, dtype, device, operations): - self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype, device, operations) - self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype, device, operations) + self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations) + self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations) def preprocess_embed(self, embed, device): result, extra = super().preprocess_embed(embed, device) if result is not None: return result, extra if embed["type"] == "audio": - audio = embed["data"].to(device, dtype=torch.float32) - audio_mask = embed.get("mask", None) + audio = embed.pop("data").to(device, dtype=torch.float32) + audio_mask = embed.pop("mask", None) if audio_mask is not None: audio_mask = audio_mask.to(device) audio_out = self.audio_model(audio, audio_mask=audio_mask) @@ -474,12 +456,6 @@ class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base): # Vision Encoder -def _parameterless_rms_norm(x, eps=1e-6): - """RMSNorm without learnable weight (used by Gemma4 v_norm and projectors).""" - mean_squared = x.float().pow(2).mean(-1, keepdim=True) + eps - return (x.float() * torch.pow(mean_squared, -0.5)).to(x.dtype) - - def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): """Compute 2D RoPE for vision: separate frequencies for x and y dimensions. 
@@ -507,16 +483,16 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No return cos, sin -def _apply_vision_2d_rope(x, cos, sin): +def _apply_vision_2d_rope(x, freqs): """Apply 2D RoPE (multidimensional) to vision query/key states. Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently. x: [batch, heads, seq, head_dim] - cos, sin: [batch, seq, head_dim] + freqs: (cos, sin) each [batch, seq, head_dim] """ - cos = cos.unsqueeze(1) # [batch, 1, seq, head_dim] - sin = sin.unsqueeze(1) + cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim] + sin = freqs[1].unsqueeze(1) def rotate_half(t): t1 = t[..., :t.shape[-1]//2] @@ -541,9 +517,8 @@ class ClippedLinear(nn.Module): Stores input_max/min and output_max/min as buffers loaded from checkpoint. """ - def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, operations=None): + def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None): super().__init__() - ops = operations or nn self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype)) self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) @@ -557,59 +532,51 @@ class ClippedLinear(nn.Module): def forward(self, x): x = x.clamp(min=self.input_min, max=self.input_max) x = self.linear(x) - x = x.clamp(min=self.output_min, max=self.output_max) - return x + return x.clamp_(min=self.output_min, max=self.output_max) class Gemma4VisionMLP(nn.Module): """SwiGLU MLP matching gate_proj/up_proj/down_proj structure.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config["intermediate_size"] - self.gate_proj = ClippedLinear(hidden_size, 
intermediate_size, device=device, dtype=dtype, operations=operations) - self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) def forward(self, x): return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)) class Gemma4VisionAttention(nn.Module): - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.hidden_size = config["hidden_size"] self.num_heads = config["num_attention_heads"] self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads) - - self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.o_proj = 
ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - def forward(self, x, cos_sin=None, attention_mask=None, **kwargs): + def forward(self, x, freqs, attention_mask=None): batch_size, seq_length, _ = x.shape xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) - xq = self.q_norm(xq) - xk = self.k_norm(xk) - xv = _parameterless_rms_norm(xv) + xq = self.q_norm(xq).transpose(1, 2) + xk = self.k_norm(xk).transpose(1, 2) + xv = rms_norm(xv, fused=False) - xq = xq.transpose(1, 2) # [B, H, S, D] - xk = xk.transpose(1, 2) - - # Apply 2D RoPE - if cos_sin is not None: - cos, sin = cos_sin - xq = _apply_vision_2d_rope(xq, cos, sin) - xk = _apply_vision_2d_rope(xk, cos, sin) + xq = _apply_vision_2d_rope(xq, freqs) + xk = _apply_vision_2d_rope(xk, freqs) xv = xv.to(xq.dtype).transpose(1, 2) @@ -618,10 +585,10 @@ class Gemma4VisionAttention(nn.Module): class Gemma4VisionLayer(nn.Module): - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() - self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, operations=operations) - self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, operations=operations) + self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) hidden = config["hidden_size"] 
self.input_layernorm = RMSNorm(hidden, **norm_kwargs) @@ -629,10 +596,10 @@ class Gemma4VisionLayer(nn.Module): self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) - def forward(self, x, cos_sin=None, attention_mask=None): + def forward(self, x, freqs, attention_mask=None): residual = x x = self.input_layernorm(x) - x = self.self_attn(x, cos_sin=cos_sin, attention_mask=attention_mask) + x = self.self_attn(x, freqs, attention_mask=attention_mask) x = self.post_attention_layernorm(x) x = residual + x @@ -646,14 +613,14 @@ class Gemma4VisionLayer(nn.Module): class Gemma4PatchEmbedder(nn.Module): """Patch embedding with learned 2D position embeddings via one-hot lookup.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] patch_size = config["patch_size"] self.patch_size = patch_size self.position_embedding_size = config.get("position_embedding_size", 10240) - self.input_proj = operations.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype) + self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype) self.position_embedding_table = nn.Parameter( torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype) ) @@ -680,16 +647,16 @@ class Gemma4PatchEmbedder(nn.Module): class Gemma4VisionEncoderLayers(nn.Module): """Wrapper to produce state dict keys as encoder.layers.X.*""" - def __init__(self, config, dtype=None, device=None, operations=None): + def __init__(self, config, dtype=None, device=None, ops=None): super().__init__() self.layers = nn.ModuleList([ - Gemma4VisionLayer(config, device=device, dtype=dtype, operations=operations) + Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops) for _ in range(config["num_hidden_layers"]) ]) 
class Gemma4VisionEncoder(nn.Module): - def __init__(self, config, dtype=None, device=None, operations=None): + def __init__(self, config, dtype=None, device=None, ops=None): super().__init__() self.config = config self.hidden_size = config["hidden_size"] @@ -698,8 +665,8 @@ class Gemma4VisionEncoder(nn.Module): self.pooling_kernel_size = config.get("pooling_kernel_size", 3) self.root_hidden_size = self.hidden_size ** 0.5 - self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, operations=operations) - self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, operations=operations) + self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops) + self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops) def forward(self, pixel_values, max_soft_tokens=None): """ @@ -720,7 +687,7 @@ class Gemma4VisionEncoder(nn.Module): grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij') position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1) - # Append zero-pixel padding with (-1,-1) positions (matching HF) + # Append zero-pixel padding with (-1,-1) positions if n_padding > 0: patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1) position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1) @@ -729,12 +696,12 @@ class Gemma4VisionEncoder(nn.Module): # Embed, encode, pool x = self.patch_embedder(patches, position_ids) - cos_sin = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) - cos_sin = tuple(t.to(x.dtype) for t in cos_sin) + freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) + freqs = tuple(t.to(x.dtype) for t in freqs) mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, 
position_ids.shape[1], -1) if n_padding > 0 else None for layer in self.encoder.layers: - x = layer(x, cos_sin=cos_sin, attention_mask=mask) + x = layer(x, freqs, attention_mask=mask) if n_padding > 0: x = x.masked_fill(padding.unsqueeze(-1), 0.0) @@ -757,36 +724,36 @@ class Gemma4VisionEncoder(nn.Module): class Gemma4RMSNormProjector(nn.Module): """Shared projector: parameterless RMSNorm → linear. Used for both vision and audio.""" - def __init__(self, in_dim, out_dim, dtype=None, device=None, operations=None): + def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None): super().__init__() - self.embedding_projection = operations.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) + self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) def forward(self, x): - return self.embedding_projection(_parameterless_rms_norm(x)) + return self.embedding_projection(rms_norm(x, fused=False)) class Gemma4MultiModalProjector(Gemma4RMSNormProjector): - def __init__(self, config, dtype=None, device=None, operations=None): - super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, operations=operations) + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops) # Audio Encoder class Gemma4AudioConvSubsampler(nn.Module): """2D convolution subsampling for audio features""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() eps = config.get("rms_norm_eps", 1e-6) self.layer0 = nn.ModuleDict({ - 'conv': operations.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), - 'norm': operations.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + 'conv': ops.Conv2d(1, 128, kernel_size=3, 
stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), }) self.layer1 = nn.ModuleDict({ - 'conv': operations.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), - 'norm': operations.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + 'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), }) # proj_input_dim = (128 // 4) * 32 = 1024 - self.input_proj_linear = operations.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) + self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) def _conv_layer(self, x, layer, mask): if mask is not None: @@ -807,26 +774,22 @@ class Gemma4AudioConvSubsampler(nn.Module): class Gemma4AudioFeedForward(nn.Module): - """Conformer feed-forward with gradient clipping and residual scaling.""" - def __init__(self, config, device=None, dtype=None, operations=None): + """Conformer feed-forward with residual scaling.""" + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config.get("intermediate_size", hidden_size * 4) self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, 
device=device, dtype=dtype, ops=ops) self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) self.post_layer_scale = config.get("residual_weight", 0.5) - self.gradient_clipping = config.get("gradient_clipping", 1e10) def forward(self, x): residual = x - gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) - x = torch.clamp(x, -gc, gc) x = self.pre_layer_norm(x) x = torch.nn.functional.silu(self.ffw_layer_1(x)) x = self.ffw_layer_2(x) - x = torch.clamp(x, -gc, gc) x = self.post_layer_norm(x) x = x * self.post_layer_scale return x + residual @@ -855,7 +818,7 @@ class Gemma4AudioRelPositionalEncoding(nn.Module): class Gemma4AudioAttention(nn.Module): """Chunked block attention with relative position bias and softcap.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.hidden_size = config["hidden_size"] self.num_heads = config["num_attention_heads"] @@ -869,12 +832,12 @@ class Gemma4AudioAttention(nn.Module): self.k_scale = math.log(1 + math.e) / math.log(2) self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False) - self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, 
self.hidden_size, device=device, dtype=dtype, ops=ops) + self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype)) - self.relative_k_proj = operations.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) + self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) def _convert_to_block(self, x): B, S, H, D = x.shape @@ -884,7 +847,6 @@ class Gemma4AudioAttention(nn.Module): return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous() def _extract_block_context(self, x): - B, S, H, D = x.shape x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)) x = x.unfold(1, self.context_size, self.chunk_size) return torch.movedim(x, -1, 2).contiguous() @@ -907,7 +869,7 @@ class Gemma4AudioAttention(nn.Module): if audio_mask is not None: mask = mask & audio_mask[0, None, :].bool() m = mask[None, None] - # Reshape to blocked 5D matching reference's _convert_4d_mask_to_blocked_5d + # Reshape to blocked 5D matching reference code p = num_blocks * self.chunk_size - seq_len m = torch.nn.functional.pad(m, (0, p, 0, p), value=False) m = m.reshape(1, 1, num_blocks, self.chunk_size, -1) @@ -957,18 +919,17 @@ class Gemma4AudioAttention(nn.Module): class Gemma4AudioLConv1d(nn.Module): """Lightweight convolution with standard GLU.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] conv_kernel_size = config.get("conv_kernel_size", 5) - self.gradient_clipping = config.get("gradient_clipping", 1e10) self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - self.linear_start = ClippedLinear(hidden_size, 
hidden_size * 2, device=device, dtype=dtype, operations=operations) + self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) # Causal conv: left-pad only - self.depthwise_conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) + self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) def forward(self, x): residual = x @@ -978,8 +939,6 @@ class Gemma4AudioLConv1d(nn.Module): x = x.transpose(1, 2) x = torch.nn.functional.pad(x, (self.conv_left_pad, 0)) x = self.depthwise_conv1d(x).transpose(1, 2) - gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) - x = torch.clamp(x, -gc, gc) x = self.conv_norm(x) x = torch.nn.functional.silu(x) x = self.linear_end(x) @@ -988,54 +947,49 @@ class Gemma4AudioLConv1d(nn.Module): class Gemma4AudioLayer(nn.Module): """Conformer block: FFN1 -> Attention -> LConv -> FFN2.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() - self.gradient_clipping = config.get("gradient_clipping", 1e10) - self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) - self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, operations=operations) + self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.self_attn = 
Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) hidden_size = config["hidden_size"] self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) - self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, operations=operations) - self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) + self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops) + self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) self.norm_out = RMSNorm(hidden_size, **norm_kwargs) def forward(self, x, position_embeddings=None, attn_mask=None): - gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) x = self.feed_forward1(x) residual = x - x = torch.clamp(x, -gc, gc) x = self.norm_pre_attn(x) x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask) - x = torch.clamp(x, -gc, gc) x = self.norm_post_attn(x) x = x + residual x = self.lconv1d(x) x = self.feed_forward2(x) - x = torch.clamp(x, -gc, gc) x = self.norm_out(x) return x class Gemma4AudioEncoder(nn.Module): - def __init__(self, config, dtype=None, device=None, operations=None): + def __init__(self, config, dtype=None, device=None, ops=None): super().__init__() self.hidden_size = config["hidden_size"] self.output_proj_dims = config.get("output_proj_dims", 1536) - self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, operations=operations) + self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops) self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype) self.layers = nn.ModuleList([ - Gemma4AudioLayer(config, device=device, dtype=dtype, operations=operations) + Gemma4AudioLayer(config, device=device, dtype=dtype, 
ops=ops) for _ in range(config["num_hidden_layers"]) ]) - self.output_proj = operations.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) + self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) def forward(self, audio_features, audio_mask=None): x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask) @@ -1054,8 +1008,8 @@ class Gemma4AudioEncoder(nn.Module): class Gemma4AudioProjector(Gemma4RMSNormProjector): - def __init__(self, config, dtype=None, device=None, operations=None): - super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, operations=operations) + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops) # Tokenizer and Wrappers @@ -1131,8 +1085,8 @@ class Gemma4_Tokenizer(): # Process audio audio_features = [] if audio is not None: - waveform = audio["waveform"].squeeze(0) if isinstance(audio, dict) else audio - sample_rate = audio.get("sample_rate", 16000) if isinstance(audio, dict) else 16000 + waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio + sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000 mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) @@ -1142,6 +1096,18 @@ class Gemma4_Tokenizer(): images = [] if source is not None: samples = source.movedim(-1, 1) # [B, C, H, W] + num_frames = samples.shape[0] + + # Subsample video to 1fps + if is_video: + fps = kwargs.get("fps", 24) + step = max(1, round(fps)) + indices = list(range(0, num_frames, step)) + if len(indices) == 0: + indices = [0] + samples = samples[indices] + num_frames = len(indices) + h, w = 
samples.shape[2], samples.shape[3] patch_size = 16 pooling_k = 3 @@ -1154,8 +1120,8 @@ class Gemma4_Tokenizer(): target_w = max(int(factor * w // side_mult) * side_mult, side_mult) import torchvision.transforms.functional as TVF - for i in range(samples.shape[0]): - # recaling to match reference code + for i in range(num_frames): + # rescaling to match reference code s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 if target_h != h or target_w != w: s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True) @@ -1176,11 +1142,9 @@ class Gemma4_Tokenizer(): media = "" if len(images) > 0: if is_video: - fps = kwargs.get("fps", 24) media += "\n\n" for i in range(len(images)): - seconds = i / fps - ts = f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" + ts = f"{int(i // 60):02d}:{int(i % 60):02d}" sep = "" if i == 0 else " " media += f"{sep}{ts} <|image><|video|>" media += "\n\n" @@ -1221,16 +1185,10 @@ class Gemma4_Tokenizer(): i += 1 if len(images) > 0: - if is_video: - # Video: batch all frames into one embed dict, each placeholder gets its frame's tokens - all_pixels = torch.cat([img["pixels"] for img in images], dim=0) # [N, H, W, C] - img_embeds = [{"type": "video", "data": all_pixels, "max_soft_tokens": images[0]["max_soft_tokens"], "frame_idx": i} for i in range(len(images))] - for r in text_tokens: - _replace_placeholders(r, 258884, img_embeds) - else: - img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] - for r in text_tokens: - _replace_placeholders(r, 258880, img_embeds) + img_token_id = 258884 if is_video else 258880 + img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] + for r in text_tokens: + _replace_placeholders(r, img_token_id, img_embeds) if len(audio_features) > 0: aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features] 
From 6718be09bae34e941aa62f5447891099668cb12f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 10 Apr 2026 15:28:26 +0300 Subject: [PATCH 04/18] cleanup, enable fused rms norm by default --- comfy/rmsnorm.py | 2 +- comfy/text_encoders/gemma4.py | 22 ++++++++++------------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 5e5ef359a..af0978341 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -4,7 +4,7 @@ import comfy.model_management RMSNorm = torch.nn.RMSNorm def rms_norm(x, weight=None, eps=1e-6, fused=True): - if not fused: + if not fused: # compatibility mode as torch native rms_norm results are slightly different orig_dtype = x.dtype normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5) if weight is not None: diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 9fac8c66a..1442f63a7 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -45,13 +45,11 @@ class Gemma4Config: num_kv_shared_layers: int = 18 use_double_wide_mlp: bool = False stop_tokens = [1, 50, 106] - fused_rms_norm: bool = False # True = use fused F.rms_norm (~64% faster, minor output difference from reference) + fused_rms_norm: bool = True # True = use fused F.rms_norm (lot faster, minor output difference from reference) vision_config = GEMMA4_VISION_CONFIG audio_config = GEMMA4_AUDIO_CONFIG mm_tokens_per_image = 280 -Gemma4_E4B_Config = Gemma4Config - @dataclass class Gemma4_E2B_Config(Gemma4Config): hidden_size: int = 1536 @@ -104,7 +102,7 @@ class Gemma4Attention(nn.Module): self.q_norm = None self.k_norm = None - fused = getattr(config, 'fused_rms_norm', False) + fused = config.fused_rms_norm if config.q_norm == "gemma3": self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.k_norm == "gemma3": @@ -188,18 +186,18 @@ class 
TransformerBlockGemma4(nn.Module): self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops) - num_kv_shared = getattr(config, 'num_kv_shared_layers', 0) + num_kv_shared = config.num_kv_shared_layers first_kv_shared = config.num_hidden_layers - num_kv_shared - mlp_size = config.intermediate_size * 2 if getattr(config, 'use_double_wide_mlp', False) and index >= first_kv_shared else None + mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) - fused = getattr(config, 'fused_rms_norm', False) + fused = config.fused_rms_norm self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) - self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0) + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) @@ -255,7 +253,7 @@ class Gemma4Transformer(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.config = config - fused = getattr(config, 'fused_rms_norm', False) + fused = config.fused_rms_norm self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, 
dtype=dtype) self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False) @@ -280,7 +278,7 @@ class Gemma4Transformer(nn.Module): self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False) # Per-layer input mechanism - self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0) + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype) self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False) @@ -354,7 +352,7 @@ class Gemma4Transformer(nn.Module): per_layer_inputs = per_layer_proj # KV sharing: later layers reuse KV from the last non-shared sliding/global layer - num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0) + num_kv_shared = self.config.num_kv_shared_layers first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers shared_sliding_kv = None # KV from last non-shared sliding layer shared_global_kv = None # KV from last non-shared global layer @@ -450,7 +448,7 @@ class Gemma4AudioMixin: class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base): def __init__(self, config_dict, dtype, device, operations): super().__init__() - self._init_model(Gemma4_E4B_Config(**config_dict), dtype, device, operations) + self._init_model(Gemma4Config(**config_dict), dtype, device, operations) self._init_audio(self.model.config, dtype, device, operations) From 6b803abe5ac8ef454530a837c7533adebb7b6f8b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Apr 2026 16:30:49 +0300 Subject: [PATCH 05/18] update comment --- 
comfy/text_encoders/gemma4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 1442f63a7..8ebae8067 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -45,7 +45,7 @@ class Gemma4Config: num_kv_shared_layers: int = 18 use_double_wide_mlp: bool = False stop_tokens = [1, 50, 106] - fused_rms_norm: bool = True # True = use fused F.rms_norm (lot faster, minor output difference from reference) + fused_rms_norm: bool = True # False: to match reference code's exact numerical behavior, which is much slower, so we default to True vision_config = GEMMA4_VISION_CONFIG audio_config = GEMMA4_AUDIO_CONFIG mm_tokens_per_image = 280 From 17ccff25bec9fbe5b148c5aa85ca7dee01fe3f7c Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Apr 2026 16:53:12 +0300 Subject: [PATCH 06/18] Cleanup --- comfy/sd.py | 17 +++----- comfy/text_encoders/gemma4.py | 76 ++++++++++++++++------------------- 2 files changed, 41 insertions(+), 52 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 7565e0f9e..3d5d738d0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1400,17 +1400,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer - elif te_model == TEModel.GEMMA_4_E4B: - clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E4B) - clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4Tokenizer - tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) - elif te_model == TEModel.GEMMA_4_E2B: - clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E2B) - clip_target.tokenizer = 
comfy.text_encoders.gemma4.Gemma4_E2BTokenizerWrapper - tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) - elif te_model == TEModel.GEMMA_4_31B: - clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_31B) - clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4_31BTokenizerWrapper + elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): + variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, + TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, + TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model] + clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant) + clip_target.tokenizer = variant.tokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 8ebae8067..9c2004a46 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -70,6 +70,7 @@ class Gemma4_31B_Config(Gemma4Config): sliding_attention = [1024, 1024, 1024, 1024, 1024, False] hidden_size_per_layer_input: int = 0 num_kv_shared_layers: int = 0 + audio_config = None vision_config = GEMMA4_VISION_31B_CONFIG @@ -82,10 +83,6 @@ def _apply_rotary_pos_emb(x, freqs_cis): out[..., half:] += x[..., :half] * sin[..., half:] return out -def _apply_rope_gemma(xq, xk, freqs_cis): - return _apply_rotary_pos_emb(xq, freqs_cis), _apply_rotary_pos_emb(xk, freqs_cis) - - class Gemma4Attention(nn.Module): def __init__(self, config, head_dim, device=None, dtype=None, ops=None): super().__init__() @@ -138,7 +135,8 @@ class Gemma4Attention(nn.Module): xv = rms_norm(xv, fused=False) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) - xq, xk = _apply_rope_gemma(xq, xk, 
freqs_cis=freqs_cis) + xq = _apply_rotary_pos_emb(xq, freqs_cis) + xk = _apply_rotary_pos_emb(xk, freqs_cis) present_key_value = None if past_key_value is not None: @@ -445,13 +443,6 @@ class Gemma4AudioMixin: return None, None -class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base): - def __init__(self, config_dict, dtype, device, operations): - super().__init__() - self._init_model(Gemma4Config(**config_dict), dtype, device, operations) - self._init_audio(self.model.config, dtype, device, operations) - - # Vision Encoder def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): @@ -559,8 +550,8 @@ class Gemma4VisionAttention(nn.Module): self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) - self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) def forward(self, x, freqs, attention_mask=None): batch_size, seq_length, _ = x.shape @@ -587,7 +578,7 @@ class Gemma4VisionLayer(nn.Module): super().__init__() self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) - norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) hidden = config["hidden_size"] self.input_layernorm = RMSNorm(hidden, **norm_kwargs) self.post_attention_layernorm = RMSNorm(hidden, 
**norm_kwargs) @@ -741,7 +732,7 @@ class Gemma4AudioConvSubsampler(nn.Module): """2D convolution subsampling for audio features""" def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() - eps = config.get("rms_norm_eps", 1e-6) + eps = config["rms_norm_eps"] self.layer0 = nn.ModuleDict({ 'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), 'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), @@ -777,10 +768,10 @@ class Gemma4AudioFeedForward(nn.Module): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config.get("intermediate_size", hidden_size * 4) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) - self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.post_layer_scale = config.get("residual_weight", 0.5) def forward(self, x): @@ -921,12 +912,12 @@ class Gemma4AudioLConv1d(nn.Module): super().__init__() hidden_size = config["hidden_size"] conv_kernel_size = config.get("conv_kernel_size", 5) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) # Causal conv: 
left-pad only self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 - self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) def forward(self, x): @@ -949,7 +940,7 @@ class Gemma4AudioLayer(nn.Module): super().__init__() self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) - norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) hidden_size = config["hidden_size"] self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) @@ -1239,10 +1230,6 @@ class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): class Gemma4Model(sd1_clip.SDClipModel): model_class = None def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): - llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) - if llama_quantization_metadata is not None: - model_options = model_options.copy() - model_options["quantization_metadata"] = llama_quantization_metadata self.dtypes = set() self.dtypes.add(dtype) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=self.model_class, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, 
model_options=model_options) @@ -1274,7 +1261,7 @@ class Gemma4Model(sd1_clip.SDClipModel): return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) -def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=Gemma4_E4B): +def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None): clip_model = type('Gemma4Model_', (Gemma4Model,), {'model_class': model_class}) class Gemma4TEModel_(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): @@ -1287,20 +1274,27 @@ def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=Ge return Gemma4TEModel_ -# Variants: config + model_class + embedding_size -class Gemma4_E2B(Gemma4AudioMixin, Gemma4Base): - def __init__(self, config_dict, dtype, device, operations): - super().__init__() - self._init_model(Gemma4_E2B_Config(**config_dict), dtype, device, operations) - self._init_audio(self.model.config, dtype, device, operations) +# Variants -class Gemma4_31B(Gemma4Base): - def __init__(self, config_dict, dtype, device, operations): - super().__init__() - self._init_model(Gemma4_31B_Config(**config_dict), dtype, device, operations) +def _make_variant(config_cls): + audio = config_cls.audio_config is not None + bases = (Gemma4AudioMixin, Gemma4Base) if audio else (Gemma4Base,) + class Variant(*bases): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self._init_model(config_cls(**config_dict), dtype, device, operations) + if audio: + self._init_audio(self.model.config, dtype, device, operations) + embedding_size = config_cls.hidden_size + if embedding_size != Gemma4SDTokenizer.embedding_size: + tok_cls = type('T', (Gemma4SDTokenizer,), {'embedding_size': embedding_size}) + class Tokenizer(Gemma4Tokenizer): + tokenizer_class = tok_cls + 
Variant.tokenizer = Tokenizer + else: + Variant.tokenizer = Gemma4Tokenizer + return Variant -class Gemma4_E2BTokenizerWrapper(Gemma4Tokenizer): - tokenizer_class = type('T', (Gemma4SDTokenizer,), {'embedding_size': 1536}) - -class Gemma4_31BTokenizerWrapper(Gemma4Tokenizer): - tokenizer_class = type('T', (Gemma4SDTokenizer,), {'embedding_size': 5376}) +Gemma4_E4B = _make_variant(Gemma4Config) +Gemma4_E2B = _make_variant(Gemma4_E2B_Config) +Gemma4_31B = _make_variant(Gemma4_31B_Config) From abff4cf3d629b115e336dda183217327b8b7836d Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Apr 2026 16:55:37 +0300 Subject: [PATCH 07/18] Update sd.py --- comfy/sd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index b0e467144..06f46211e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1238,9 +1238,9 @@ class TEModel(Enum): QWEN35_9B = 26 QWEN35_27B = 27 MINISTRAL_3_3B = 28 - GEMMA_4_E4B = 28 - GEMMA_4_E2B = 29 - GEMMA_4_31B = 30 + GEMMA_4_E4B = 29 + GEMMA_4_E2B = 30 + GEMMA_4_31B = 31 def detect_te_model(sd): From 0fc398a821bc7362ea2812b6db3d8f42f50af7e3 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:11:07 +0300 Subject: [PATCH 08/18] Various fixes --- comfy/ldm/modules/attention.py | 6 +++++- comfy/sd.py | 2 +- comfy/text_encoders/gemma4.py | 5 +++++ comfy_extras/nodes_textgen.py | 4 ++-- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 43cecad7f..a68cb8439 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management +TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5) + if model_management.xformers_enabled(): import xformers import xformers.ops @@ -510,7 +512,9 @@ def attention_pytorch(q, k, 
v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) # Pass through extra SDPA kwargs (scale, enable_gqa) if provided - sdpa_extra = {k: v for k, v in kwargs.items() if k in ("scale", "enable_gqa")} + # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above + sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",) + sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys} if SDP_BATCH_LIMIT >= b: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra) diff --git a/comfy/sd.py b/comfy/sd.py index 06f46211e..3c19a4bb6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1295,7 +1295,7 @@ def detect_te_model(sd): if weight.shape[0] == 4096: return TEModel.QWEN35_9B if weight.shape[0] == 5120: - return TEModel.QWEN35_31B + return TEModel.QWEN35_27B return TEModel.QWEN35_2B if "model.layers.0.post_attention_layernorm.weight" in sd: weight = sd['model.layers.0.post_attention_layernorm.weight'] diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 9c2004a46..9573cd427 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -1004,7 +1004,11 @@ class Gemma4AudioProjector(Gemma4RMSNormProjector): # Tokenizer and Wrappers class Gemma4_Tokenizer(): + tokenizer_json_data = None + def state_dict(self): + if self.tokenizer_json_data is not None: + return {"tokenizer_json": self.tokenizer_json_data} return {} def _extract_mel_spectrogram(self, waveform, sample_rate): @@ -1217,6 +1221,7 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): embedding_size = 2560 def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_json = tokenizer_data.get("tokenizer_json", None) + self.tokenizer_json_data = tokenizer_json super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, 
has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 4235fd310..b4f793f9a 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -162,12 +162,12 @@ class TextGenerateLTX2Prompt(TextGenerate): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput: if image is None: formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n" else: formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n" - return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking) + return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking) class TextgenExtension(ComfyExtension): From ba3a484c065e8803bb2042ac3d72fcd0d42b4386 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:16:21 +0300 Subject: [PATCH 09/18] Add fp8 scaled embedding support --- comfy/ops.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index b5cd1d47e..b33fde1aa 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1159,6 +1159,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self._buffers[key] = fn(buf) return self + class Embedding(manual_cast.Embedding): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + weight_key = f"{prefix}weight" + layer_conf = 
state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + # Only fp8 makes sense for embeddings (per-row dequant via index select). + # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. + quant_format = layer_conf.get("format", None) if layer_conf is not None else None + if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + self.quant_format = quant_format + qconfig = QUANT_ALGOS[quant_format] + layout_cls = get_layout_class(qconfig["comfy_tensor_layout"]) + weight = state_dict.pop(weight_key) + manually_loaded_keys = [weight_key] + + scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(scale_key, None) + if scale is not None: + scale = scale.float() + manually_loaded_keys.append(scale_key) + + params = layout_cls.Params( + scale=scale if scale is not None else torch.ones((), dtype=torch.float32), + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.num_embeddings, self.embedding_dim), + ) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), + requires_grad=False) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) + else: + if layer_conf is not None: + state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + if destination is not None: + sd = destination + else: + sd = {} + + if not hasattr(self, 'weight') or self.weight is None: + return sd + + if isinstance(self.weight, QuantizedTensor): + sd_out = 
self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] + + quant_conf = {"format": self.quant_format} + sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + else: + sd["{}weight".format(prefix)] = self.weight + return sd + + def forward_comfy_cast_weights(self, input, out_dtype=None): + weight = self.weight + + # Optimized path: lookup in fp8, dequantize only the selected rows. + if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0: + qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True) + if isinstance(qdata, QuantizedTensor): + scale = qdata._params.scale + qdata = qdata._qdata + else: + scale = None + + x = torch.nn.functional.embedding( + input, qdata, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + uncast_bias_weight(self, qdata, None, offload_stream) + target_dtype = out_dtype if out_dtype is not None else weight.params.orig_dtype + x = x.to(dtype=target_dtype) + if scale is not None and scale != 1.0: + x = x * scale.to(dtype=target_dtype) + return x + + # Fallback for non-quantized or weight_function (LoRA) case + return super().forward_comfy_cast_weights(input, out_dtype=out_dtype) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): From 387e8d8a4c6e068c7e2958fd5e69b63a278ee1d1 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:27:26 +0300 Subject: [PATCH 10/18] small fixes --- comfy/ops.py | 2 +- comfy/text_encoders/gemma4.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index b33fde1aa..9f9041e69 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1237,7 +1237,7 @@ def mixed_precision_ops(quant_config={}, 
compute_dtype=torch.bfloat16, full_prec input, qdata, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) uncast_bias_weight(self, qdata, None, offload_stream) - target_dtype = out_dtype if out_dtype is not None else weight.params.orig_dtype + target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype x = x.to(dtype=target_dtype) if scale is not None and scale != 1.0: x = x * scale.to(dtype=target_dtype) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 9573cd427..8905f375f 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -1073,7 +1073,6 @@ class Gemma4_Tokenizer(): return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes)) def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, video=None, llama_template=None, skip_template=True, thinking=False, **kwargs): - self.thinking = thinking # Process audio audio_features = [] @@ -1131,7 +1130,7 @@ class Gemma4_Tokenizer(): llama_text = llama_template.format(text) else: # Build template from modalities present - system = "<|turn>system\n<|think|>\n" if self.thinking else "" + system = "<|turn>system\n<|think|>\n" if thinking else "" media = "" if len(images) > 0: if is_video: From c857b6c65736e45fd9a8d6618c78e712fe0bc96a Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Apr 2026 21:25:19 +0300 Subject: [PATCH 11/18] Translate think tokens --- comfy_extras/nodes_textgen.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index b4f793f9a..0d4cf3a2b 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -71,7 +71,15 @@ class TextGenerate(io.ComfyNode): seed=seed ) - generated_text = clip.decode(generated_ids, skip_special_tokens=True) + generated_text = clip.decode(generated_ids, skip_special_tokens=not thinking) + 
+ if thinking: + # Translate Gemma4 thinking channel markers to standard / tags + generated_text = generated_text.replace("<|channel>thought\n", "\n") + generated_text = generated_text.replace("", "") + # Strip remaining special tokens + generated_text = generated_text.replace("", "").replace("", "").strip() + return io.NodeOutput(generated_text) From e0cccbd4c919465cccf7846baff39b394e5638d5 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 13 Apr 2026 23:29:27 +0300 Subject: [PATCH 12/18] Fix image encoder attention mask type So it works with basic attention --- comfy/text_encoders/gemma4.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 8905f375f..7c3df9c09 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -687,7 +687,11 @@ class Gemma4VisionEncoder(nn.Module): x = self.patch_embedder(patches, position_ids) freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) freqs = tuple(t.to(x.dtype) for t in freqs) - mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None + if n_padding > 0: + mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) + mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min) + else: + mask = None for layer in self.encoder.layers: x = layer(x, freqs, attention_mask=mask) From 845eb1442552ea63ebd30b8003204fa10d43ab49 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 13 Apr 2026 23:40:52 +0300 Subject: [PATCH 13/18] Handle thinking tokens different only for Gemma4 --- comfy/text_encoders/gemma4.py | 9 +++++++++ comfy_extras/nodes_textgen.py | 9 +-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 
7c3df9c09..78ad81741 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -1227,6 +1227,15 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): self.tokenizer_json_data = tokenizer_json super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) + def decode(self, token_ids, **kwargs): + text = super().decode(token_ids, skip_special_tokens=False) + # Translate thinking channel markers to standard / tags + text = text.replace("<|channel>thought\n", "\n") + text = text.replace("", "") + # Strip remaining special tokens + text = text.replace("", "").replace("", "").strip() + return text + class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): tokenizer_class = Gemma4SDTokenizer diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 0d4cf3a2b..ec81159d3 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -71,14 +71,7 @@ class TextGenerate(io.ComfyNode): seed=seed ) - generated_text = clip.decode(generated_ids, skip_special_tokens=not thinking) - - if thinking: - # Translate Gemma4 thinking channel markers to standard / tags - generated_text = generated_text.replace("<|channel>thought\n", "\n") - generated_text = generated_text.replace("", "") - # Strip remaining special tokens - generated_text = generated_text.replace("", "").replace("", "").strip() + generated_text = clip.decode(generated_ids) return io.NodeOutput(generated_text) From 80af0327621fa9e6e40c7741be08caf06dafd293 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:44:27 +0300 Subject: [PATCH 14/18] Code cleanup --- comfy/text_encoders/gemma4.py | 25 +++++-------------------- 1 file changed, 5 
insertions(+), 20 deletions(-) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 78ad81741..68d67ef05 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -475,30 +475,17 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No def _apply_vision_2d_rope(x, freqs): """Apply 2D RoPE (multidimensional) to vision query/key states. - Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently. + Splits x and cos/sin into ndim=2 parts, applies 1D RoPE to each independently. x: [batch, heads, seq, head_dim] freqs: (cos, sin) each [batch, seq, head_dim] """ cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim] sin = freqs[1].unsqueeze(1) - - def rotate_half(t): - t1 = t[..., :t.shape[-1]//2] - t2 = t[..., t.shape[-1]//2:] - return torch.cat((-t2, t1), dim=-1) - - # Split into 2 parts (y and x dimensions) half = x.shape[-1] // 2 - x_parts = [x[..., :half], x[..., half:]] - cos_parts = [cos[..., :half], cos[..., half:]] - sin_parts = [sin[..., :half], sin[..., half:]] - - rotated_parts = [] - for xp, cp, sp in zip(x_parts, cos_parts, sin_parts): - rotated_parts.append((xp * cp) + (rotate_half(xp) * sp)) - - return torch.cat(rotated_parts, dim=-1) + a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half])) + b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:])) + return torch.cat([a, b], dim=-1) class ClippedLinear(nn.Module): @@ -622,10 +609,8 @@ class Gemma4PatchEmbedder(nn.Module): hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype)) clamped_positions = pixel_position_ids.clamp(min=0) - one_hot = torch.nn.functional.one_hot(clamped_positions, num_classes=self.position_embedding_size) pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype) - one_hot = one_hot.permute(0, 2, 1, 3).to(pos_table) - position_embeddings = 
(one_hot @ pos_table).sum(dim=1) + position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]] # Zero out position embeddings for padding patches (matching HF) padding_positions = (pixel_position_ids == -1).all(dim=-1) From 5b470171dfdadfde204595d0df50f3ebdebbed68 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:55:25 +0300 Subject: [PATCH 15/18] Update nodes_textgen.py --- comfy_extras/nodes_textgen.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 919f43d4a..1661a1011 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -32,7 +32,7 @@ class TextGenerate(io.ComfyNode): io.Clip.Input("clip"), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.Image.Input("image", optional=True), - io.Image.Input("video", optional=True, tooltip="Video frames as image batch (1 FPS recommended)."), + io.Image.Input("video", optional=True, tooltip="Video frames as image batch. 
Assumed to be 24 FPS; subsampled to 1 FPS internally."), io.Audio.Input("audio", optional=True), io.Int.Input("max_length", default=256, min=1, max=2048), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), @@ -45,9 +45,9 @@ class TextGenerate(io.ComfyNode): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False, use_default_template=True) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput: - tokens = clip.tokenize(prompt, image=image, video=video, audio=audio, skip_template=not use_default_template, min_length=1, thinking=thinking) + tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio) # Get sampling parameters from dynamic combo do_sample = sampling_mode.get("sampling_mode") == "on" @@ -164,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False, use_default_template=True) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput: if image is None: formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n" else: formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n" - return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking, use_default_template) + return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, 
audio=audio) class TextgenExtension(ComfyExtension): From 4257b8f35cb6058c123a40a4cfdbaee37cd260b6 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 27 Apr 2026 19:05:38 +0300 Subject: [PATCH 16/18] Use embed scale class instead of buffer Slight difference to HF, but technically more accurate and simpler code --- comfy/text_encoders/gemma4.py | 10 +++------- comfy/text_encoders/llama.py | 17 +++++++---------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 68d67ef05..1b70eadb5 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -8,7 +8,7 @@ from comfy import sd1_clip import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.rmsnorm import rms_norm -from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook +from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} @@ -253,9 +253,7 @@ class Gemma4Transformer(nn.Module): self.config = config fused = config.fused_rms_norm - self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) - self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False) - self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook) + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) self.layers = nn.ModuleList([ TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops) 
@@ -278,9 +276,7 @@ class Gemma4Transformer(nn.Module): # Per-layer input mechanism self.hidden_size_per_layer_input = config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: - self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype) - self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False) - self.embed_tokens_per_layer.register_forward_hook(_gemma_embed_scale_hook) + self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype) self.per_layer_model_projection = ops.Linear( config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 1c4fc26af..d1c43adb2 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -648,8 +648,11 @@ class TransformerBlockGemma2(nn.Module): return x, present_key_value -def _gemma_embed_scale_hook(module, input, output): - return (output.to(module._embed_scale.dtype) * module._embed_scale).to(output.dtype) +def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype): + class ScaledEmbedding(ops.Embedding): + def forward(self, input_ids, out_dtype=None): + return super().forward(input_ids, out_dtype=out_dtype) * scale + return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype) class Llama2_(nn.Module): @@ -658,18 +661,12 @@ class Llama2_(nn.Module): self.config = config self.vocab_size = config.vocab_size - self.embed_tokens = ops.Embedding( - config.vocab_size, - config.hidden_size, - device=device, - dtype=dtype - ) if self.config.transformer_type == "gemma2" or 
self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 - self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False) - self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook) + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) else: transformer = TransformerBlock + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ transformer(config, index=i, device=device, dtype=dtype, ops=ops) From ee728a795f2f44bf3d03fd1bebdde23066dfac66 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:30:52 +0300 Subject: [PATCH 17/18] Default to fused rms_norm --- comfy/rmsnorm.py | 11 ++----- comfy/text_encoders/gemma4.py | 54 +++++++++++++++++------------------ comfy/text_encoders/llama.py | 5 ++-- 3 files changed, 31 insertions(+), 39 deletions(-) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index af0978341..e54be98d6 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -3,15 +3,8 @@ import comfy.model_management RMSNorm = torch.nn.RMSNorm -def rms_norm(x, weight=None, eps=1e-6, fused=True): - if not fused: # compatibility mode as torch native rms_norm results are slightly different - orig_dtype = x.dtype - normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5) - if weight is not None: - weight = comfy.model_management.cast_to(weight, dtype=torch.float32, device=x.device) - normed = normed * weight - return normed.to(orig_dtype) - +# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding). 
+def rms_norm(x, weight=None, eps=1e-6): if weight is None: return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) else: diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 1b70eadb5..61ff42501 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -11,6 +11,12 @@ from comfy.rmsnorm import rms_norm from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding +# Intentional minor divergences from transformers -reference implementation: +# Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor. +# RMSNorm uses torch fused F.rms_norm +# Input image and audio resizing/resampling slightly different numerically + + GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} @@ -45,7 +51,6 @@ class Gemma4Config: num_kv_shared_layers: int = 18 use_double_wide_mlp: bool = False stop_tokens = [1, 50, 106] - fused_rms_norm: bool = True # False: to match reference code's exact numerical behavior, which is much slower, so we default to True vision_config = GEMMA4_VISION_CONFIG audio_config = GEMMA4_AUDIO_CONFIG mm_tokens_per_image = 
280 @@ -99,11 +104,10 @@ class Gemma4Attention(nn.Module): self.q_norm = None self.k_norm = None - fused = config.fused_rms_norm if config.q_norm == "gemma3": - self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.k_norm == "gemma3": - self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) def forward( self, @@ -132,7 +136,7 @@ class Gemma4Attention(nn.Module): xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) if self.k_norm is not None: xk = self.k_norm(xk) - xv = rms_norm(xv, fused=False) + xv = rms_norm(xv) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) xq = _apply_rotary_pos_emb(xq, freqs_cis) @@ -189,17 +193,16 @@ class TransformerBlockGemma4(nn.Module): mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) - fused = config.fused_rms_norm - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) - self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.pre_feedforward_layernorm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.hidden_size_per_layer_input = config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) - self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) + self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) else: self.layer_scalar = None @@ -251,7 +254,6 @@ class Gemma4Transformer(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.config = config - fused = config.fused_rms_norm self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) @@ -260,7 +262,7 @@ class Gemma4Transformer(nn.Module): for i in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype, fused=fused) if config.final_norm else None + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None # Precompute RoPE inv_freq on CPU to match reference code's exact value rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2) @@ -282,7 +284,7 @@ class Gemma4Transformer(nn.Module): bias=False, device=device, dtype=dtype) self.per_layer_projection_norm = RMSNorm( self.hidden_size_per_layer_input, eps=config.rms_norm_eps, - device=device, dtype=dtype, 
fused=fused) + device=device, dtype=dtype) def get_past_len(self, past_key_values): for kv in past_key_values: @@ -533,8 +535,8 @@ class Gemma4VisionAttention(nn.Module): self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) - self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) - self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) def forward(self, x, freqs, attention_mask=None): batch_size, seq_length, _ = x.shape @@ -545,7 +547,7 @@ class Gemma4VisionAttention(nn.Module): xq = self.q_norm(xq).transpose(1, 2) xk = self.k_norm(xk).transpose(1, 2) - xv = rms_norm(xv, fused=False) + xv = rms_norm(xv) xq = _apply_vision_2d_rope(xq, freqs) xk = _apply_vision_2d_rope(xk, freqs) @@ -561,7 +563,7 @@ class Gemma4VisionLayer(nn.Module): super().__init__() self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) - norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) hidden = config["hidden_size"] self.input_layernorm = RMSNorm(hidden, **norm_kwargs) self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs) @@ -703,7 +705,7 @@ class Gemma4RMSNormProjector(nn.Module): self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) def forward(self, x): - return self.embedding_projection(rms_norm(x, fused=False)) + return self.embedding_projection(rms_norm(x)) class 
Gemma4MultiModalProjector(Gemma4RMSNormProjector): @@ -753,10 +755,10 @@ class Gemma4AudioFeedForward(nn.Module): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config.get("intermediate_size", hidden_size * 4) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) - self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.post_layer_scale = config.get("residual_weight", 0.5) def forward(self, x): @@ -897,12 +899,12 @@ class Gemma4AudioLConv1d(nn.Module): super().__init__() hidden_size = config["hidden_size"] conv_kernel_size = config.get("conv_kernel_size", 5) - self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) # Causal conv: left-pad only self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 - self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, 
ops=ops) def forward(self, x): @@ -925,7 +927,7 @@ class Gemma4AudioLayer(nn.Module): super().__init__() self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) - norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) hidden_size = config["hidden_size"] self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) @@ -1007,9 +1009,7 @@ class Gemma4_Tokenizer(): waveform = waveform.unsqueeze(0) audio = waveform.squeeze(0).float().numpy() if sample_rate != 16000: - # import librosa - # audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000) - # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (still not full match) + # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match) from scipy.signal import resample_poly, firwin from math import gcd g = gcd(sample_rate, 16000) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d1c43adb2..a34c41144 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -382,19 +382,18 @@ class Gemma3_12B_Config: stop_tokens = [1, 106] class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None, fused=True): + def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) self.add = add - self.fused = fused def forward(self, x: torch.Tensor): w = self.weight if self.add: w = w + 1.0 - return comfy.ldm.common_dit.rms_norm(x, w, self.eps, fused=self.fused) + return 
comfy.ldm.common_dit.rms_norm(x, w, self.eps) From de5e490e42f98a794aecfaff7c0e913ef8d2456c Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:31:23 +0300 Subject: [PATCH 18/18] Update gemma4.py --- comfy/text_encoders/gemma4.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 61ff42501..f050061ed 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -12,9 +12,9 @@ from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _ma # Intentional minor divergences from transformers -reference implementation: -# Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor. -# RMSNorm uses torch fused F.rms_norm -# Input image and audio resizing/resampling slightly different numerically +# - Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor. +# - RMSNorm uses torch fused F.rms_norm, very slight numerical differences, but considerably faster +# - Input image and audio resizing/resampling slightly different numerically GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}