diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index ee4c672f2..9fac8c66a 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -7,12 +7,13 @@ import math from comfy import sd1_clip import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.rmsnorm import rms_norm from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook -GEMMA4_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "model_type": "gemma4_vision", "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} -GEMMA4_VISION_31B_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "gemma4_vision", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} -GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5, "gradient_clipping": 1e10, "hidden_act": "silu"} +GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} @dataclass class Gemma4Config: @@ -74,6 +75,7 @@ class Gemma4_31B_Config(Gemma4Config): vision_config = GEMMA4_VISION_31B_CONFIG +# unfused RoPE as addcmul_ RoPE diverges from reference code def _apply_rotary_pos_emb(x, freqs_cis): cos, sin = freqs_cis[0], freqs_cis[1] half = x.shape[-1] // 2 @@ -82,7 +84,6 @@ def _apply_rotary_pos_emb(x, freqs_cis): out[..., half:] += x[..., :half] * sin[..., half:] return out - def _apply_rope_gemma(xq, xk, freqs_cis): return _apply_rotary_pos_emb(xq, freqs_cis), _apply_rotary_pos_emb(xk, freqs_cis) @@ -117,7 +118,6 @@ class Gemma4Attention(nn.Module): past_key_value=None, sliding_window=None, shared_kv=None, - **kwargs, ): batch_size, seq_length, _ = hidden_states.shape @@ -137,7 +137,7 @@ class Gemma4Attention(nn.Module): xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) if self.k_norm is not None: xk = self.k_norm(xk) - xv = _parameterless_rms_norm(xv) + xv = rms_norm(xv, fused=False) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) xq, xk = _apply_rope_gemma(xq, xk, freqs_cis=freqs_cis) @@ -298,22 +298,17 @@ class Gemma4Transformer(nn.Module): return kv[2] return 0 - def _freqs_from_inv(self, inv_freq, position_ids, dtype=None): + def _freqs_from_inv(self, inv_freq, position_ids, device, dtype): """Compute cos/sin from stored inv_freq""" - inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(position_ids.device) + inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device) pos_exp = position_ids[:, None, :].float() freqs = (inv_exp @ pos_exp).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().unsqueeze(1) - sin = emb.sin().unsqueeze(1) - result = (cos, sin) - if dtype is not None: - result = tuple(t.to(dtype) for t in result) - return result + return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype) def compute_freqs_cis(self, position_ids, device, dtype=None): - global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, dtype) - sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, dtype) + global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype) + sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype) return [global_freqs, sliding_freqs] def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, @@ -358,7 +353,7 @@ class Gemma4Transformer(nn.Module): else: per_layer_inputs = per_layer_proj - # KV sharing: only last sliding (22) and last global (23) layers store KV for sharing + # KV sharing: later layers reuse KV from the last non-shared sliding/global layer num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0) first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers shared_sliding_kv = None # KV from last non-shared sliding layer @@ -407,8 +402,8 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): self.num_layers = config.num_hidden_layers self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype - self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype, device, operations) - self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype, device, operations) + self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations) + self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations) def logits(self, x): logits = super().logits(x) @@ -425,39 +420,26 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): def preprocess_embed(self, embed, device): if embed["type"] == "image": - image = embed["data"].movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] + image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] max_soft_tokens = embed.get("max_soft_tokens", None) vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) return self.multi_modal_projector(vision_out), None - if embed["type"] == "video": - frame_idx = embed.get("frame_idx", 0) - if not hasattr(self, '_video_cache') or self._video_cache is None: - # First frame: process all frames as a batch - frames = embed["data"].movedim(-1, 1) # [N, H, W, C] -> [N, C, H, W] - max_soft_tokens = embed.get("max_soft_tokens", None) - vision_out = self.vision_model(frames.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) - projected = self.multi_modal_projector(vision_out) # [N, tokens_per_frame, hidden] - self._video_cache = projected - result = self._video_cache[frame_idx:frame_idx+1] # [1, tokens_per_frame, hidden] - if frame_idx == self._video_cache.shape[0] - 1: - self._video_cache = None # clear after last frame - return result, None return None, None class Gemma4AudioMixin: """Adds audio support to a Gemma4 model.""" def _init_audio(self, config, dtype, device, operations): - self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype, device, operations) - self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype, device, operations) + self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations) + self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations) def preprocess_embed(self, embed, device): result, extra = super().preprocess_embed(embed, device) if result is not None: return result, extra if embed["type"] == "audio": - audio = embed["data"].to(device, dtype=torch.float32) - audio_mask = embed.get("mask", None) + audio = embed.pop("data").to(device, dtype=torch.float32) + audio_mask = embed.pop("mask", None) if audio_mask is not None: audio_mask = audio_mask.to(device) audio_out = self.audio_model(audio, audio_mask=audio_mask) @@ -474,12 +456,6 @@ class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base): # Vision Encoder -def _parameterless_rms_norm(x, eps=1e-6): - """RMSNorm without learnable weight (used by Gemma4 v_norm and projectors).""" - mean_squared = x.float().pow(2).mean(-1, keepdim=True) + eps - return (x.float() * torch.pow(mean_squared, -0.5)).to(x.dtype) - - def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): """Compute 2D RoPE for vision: separate frequencies for x and y dimensions. @@ -507,16 +483,16 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No return cos, sin -def _apply_vision_2d_rope(x, cos, sin): +def _apply_vision_2d_rope(x, freqs): """Apply 2D RoPE (multidimensional) to vision query/key states. Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently. x: [batch, heads, seq, head_dim] - cos, sin: [batch, seq, head_dim] + freqs: (cos, sin) each [batch, seq, head_dim] """ - cos = cos.unsqueeze(1) # [batch, 1, seq, head_dim] - sin = sin.unsqueeze(1) + cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim] + sin = freqs[1].unsqueeze(1) def rotate_half(t): t1 = t[..., :t.shape[-1]//2] @@ -541,9 +517,8 @@ class ClippedLinear(nn.Module): Stores input_max/min and output_max/min as buffers loaded from checkpoint. """ - def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, operations=None): + def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None): super().__init__() - ops = operations or nn self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype)) self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) @@ -557,59 +532,51 @@ class ClippedLinear(nn.Module): def forward(self, x): x = x.clamp(min=self.input_min, max=self.input_max) x = self.linear(x) - x = x.clamp(min=self.output_min, max=self.output_max) - return x + return x.clamp_(min=self.output_min, max=self.output_max) class Gemma4VisionMLP(nn.Module): """SwiGLU MLP matching gate_proj/up_proj/down_proj structure.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config["intermediate_size"] - self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) def forward(self, x): return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)) class Gemma4VisionAttention(nn.Module): - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.hidden_size = config["hidden_size"] self.num_heads = config["num_attention_heads"] self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads) - - self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations) - self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - def forward(self, x, cos_sin=None, attention_mask=None, **kwargs): + def forward(self, x, freqs, attention_mask=None): batch_size, seq_length, _ = x.shape xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) - xq = self.q_norm(xq) - xk = self.k_norm(xk) - xv = _parameterless_rms_norm(xv) + xq = self.q_norm(xq).transpose(1, 2) + xk = self.k_norm(xk).transpose(1, 2) + xv = rms_norm(xv, fused=False) - xq = xq.transpose(1, 2) # [B, H, S, D] - xk = xk.transpose(1, 2) - - # Apply 2D RoPE - if cos_sin is not None: - cos, sin = cos_sin - xq = _apply_vision_2d_rope(xq, cos, sin) - xk = _apply_vision_2d_rope(xk, cos, sin) + xq = _apply_vision_2d_rope(xq, freqs) + xk = _apply_vision_2d_rope(xk, freqs) xv = xv.to(xq.dtype).transpose(1, 2) @@ -618,10 +585,10 @@ class Gemma4VisionAttention(nn.Module): class Gemma4VisionLayer(nn.Module): - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() - self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, operations=operations) - self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, operations=operations) + self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) hidden = config["hidden_size"] self.input_layernorm = RMSNorm(hidden, **norm_kwargs) @@ -629,10 +596,10 @@ class Gemma4VisionLayer(nn.Module): self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) - def forward(self, x, cos_sin=None, attention_mask=None): + def forward(self, x, freqs, attention_mask=None): residual = x x = self.input_layernorm(x) - x = self.self_attn(x, cos_sin=cos_sin, attention_mask=attention_mask) + x = self.self_attn(x, freqs, attention_mask=attention_mask) x = self.post_attention_layernorm(x) x = residual + x @@ -646,14 +613,14 @@ class Gemma4VisionLayer(nn.Module): class Gemma4PatchEmbedder(nn.Module): """Patch embedding with learned 2D position embeddings via one-hot lookup.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] patch_size = config["patch_size"] self.patch_size = patch_size self.position_embedding_size = config.get("position_embedding_size", 10240) - self.input_proj = operations.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype) + self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype) self.position_embedding_table = nn.Parameter( torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype) ) @@ -680,16 +647,16 @@ class Gemma4PatchEmbedder(nn.Module): class Gemma4VisionEncoderLayers(nn.Module): """Wrapper to produce state dict keys as encoder.layers.X.*""" - def __init__(self, config, dtype=None, device=None, operations=None): + def __init__(self, config, dtype=None, device=None, ops=None): super().__init__() self.layers = nn.ModuleList([ - Gemma4VisionLayer(config, device=device, dtype=dtype, operations=operations) + Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops) for _ in range(config["num_hidden_layers"]) ]) class Gemma4VisionEncoder(nn.Module): - def __init__(self, config, dtype=None, device=None, operations=None): + def __init__(self, config, dtype=None, device=None, ops=None): super().__init__() self.config = config self.hidden_size = config["hidden_size"] @@ -698,8 +665,8 @@ class Gemma4VisionEncoder(nn.Module): self.pooling_kernel_size = config.get("pooling_kernel_size", 3) self.root_hidden_size = self.hidden_size ** 0.5 - self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, operations=operations) - self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, operations=operations) + self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops) + self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops) def forward(self, pixel_values, max_soft_tokens=None): """ @@ -720,7 +687,7 @@ class Gemma4VisionEncoder(nn.Module): grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij') position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1) - # Append zero-pixel padding with (-1,-1) positions (matching HF) + # Append zero-pixel padding with (-1,-1) positions if n_padding > 0: patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1) position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1) @@ -729,12 +696,12 @@ class Gemma4VisionEncoder(nn.Module): # Embed, encode, pool x = self.patch_embedder(patches, position_ids) - cos_sin = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) - cos_sin = tuple(t.to(x.dtype) for t in cos_sin) + freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) + freqs = tuple(t.to(x.dtype) for t in freqs) mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None for layer in self.encoder.layers: - x = layer(x, cos_sin=cos_sin, attention_mask=mask) + x = layer(x, freqs, attention_mask=mask) if n_padding > 0: x = x.masked_fill(padding.unsqueeze(-1), 0.0) @@ -757,36 +724,36 @@ class Gemma4VisionEncoder(nn.Module): class Gemma4RMSNormProjector(nn.Module): """Shared projector: parameterless RMSNorm → linear. Used for both vision and audio.""" - def __init__(self, in_dim, out_dim, dtype=None, device=None, operations=None): + def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None): super().__init__() - self.embedding_projection = operations.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) + self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) def forward(self, x): - return self.embedding_projection(_parameterless_rms_norm(x)) + return self.embedding_projection(rms_norm(x, fused=False)) class Gemma4MultiModalProjector(Gemma4RMSNormProjector): - def __init__(self, config, dtype=None, device=None, operations=None): - super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, operations=operations) + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops) # Audio Encoder class Gemma4AudioConvSubsampler(nn.Module): """2D convolution subsampling for audio features""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() eps = config.get("rms_norm_eps", 1e-6) self.layer0 = nn.ModuleDict({ - 'conv': operations.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), - 'norm': operations.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + 'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), }) self.layer1 = nn.ModuleDict({ - 'conv': operations.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), - 'norm': operations.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + 'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), }) # proj_input_dim = (128 // 4) * 32 = 1024 - self.input_proj_linear = operations.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) + self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) def _conv_layer(self, x, layer, mask): if mask is not None: @@ -807,26 +774,22 @@ class Gemma4AudioConvSubsampler(nn.Module): class Gemma4AudioFeedForward(nn.Module): - """Conformer feed-forward with gradient clipping and residual scaling.""" - def __init__(self, config, device=None, dtype=None, operations=None): + """Conformer feed-forward with residual scaling.""" + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] intermediate_size = config.get("intermediate_size", hidden_size * 4) self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) - self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) self.post_layer_scale = config.get("residual_weight", 0.5) - self.gradient_clipping = config.get("gradient_clipping", 1e10) def forward(self, x): residual = x - gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) - x = torch.clamp(x, -gc, gc) x = self.pre_layer_norm(x) x = torch.nn.functional.silu(self.ffw_layer_1(x)) x = self.ffw_layer_2(x) - x = torch.clamp(x, -gc, gc) x = self.post_layer_norm(x) x = x * self.post_layer_scale return x + residual @@ -855,7 +818,7 @@ class Gemma4AudioRelPositionalEncoding(nn.Module): class Gemma4AudioAttention(nn.Module): """Chunked block attention with relative position bias and softcap.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.hidden_size = config["hidden_size"] self.num_heads = config["num_attention_heads"] @@ -869,12 +832,12 @@ class Gemma4AudioAttention(nn.Module): self.k_scale = math.log(1 + math.e) / math.log(2) self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False) - self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) - self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations) + self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype)) - self.relative_k_proj = operations.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) + self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) def _convert_to_block(self, x): B, S, H, D = x.shape @@ -884,7 +847,6 @@ class Gemma4AudioAttention(nn.Module): return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous() def _extract_block_context(self, x): - B, S, H, D = x.shape x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)) x = x.unfold(1, self.context_size, self.chunk_size) return torch.movedim(x, -1, 2).contiguous() @@ -907,7 +869,7 @@ class Gemma4AudioAttention(nn.Module): if audio_mask is not None: mask = mask & audio_mask[0, None, :].bool() m = mask[None, None] - # Reshape to blocked 5D matching reference's _convert_4d_mask_to_blocked_5d + # Reshape to blocked 5D matching reference code p = num_blocks * self.chunk_size - seq_len m = torch.nn.functional.pad(m, (0, p, 0, p), value=False) m = m.reshape(1, 1, num_blocks, self.chunk_size, -1) @@ -957,18 +919,17 @@ class Gemma4AudioAttention(nn.Module): class Gemma4AudioLConv1d(nn.Module): """Lightweight convolution with standard GLU.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() hidden_size = config["hidden_size"] conv_kernel_size = config.get("conv_kernel_size", 5) - self.gradient_clipping = config.get("gradient_clipping", 1e10) self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, operations=operations) + self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) # Causal conv: left-pad only - self.depthwise_conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) + self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) - self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, operations=operations) + self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) def forward(self, x): residual = x @@ -978,8 +939,6 @@ class Gemma4AudioLConv1d(nn.Module): x = x.transpose(1, 2) x = torch.nn.functional.pad(x, (self.conv_left_pad, 0)) x = self.depthwise_conv1d(x).transpose(1, 2) - gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) - x = torch.clamp(x, -gc, gc) x = self.conv_norm(x) x = torch.nn.functional.silu(x) x = self.linear_end(x) @@ -988,54 +947,49 @@ class Gemma4AudioLConv1d(nn.Module): class Gemma4AudioLayer(nn.Module): """Conformer block: FFN1 -> Attention -> LConv -> FFN2.""" - def __init__(self, config, device=None, dtype=None, operations=None): + def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() - self.gradient_clipping = config.get("gradient_clipping", 1e10) - self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) - self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, operations=operations) + self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False) hidden_size = config["hidden_size"] self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) - self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, operations=operations) - self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations) + self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops) + self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) self.norm_out = RMSNorm(hidden_size, **norm_kwargs) def forward(self, x, position_embeddings=None, attn_mask=None): - gc = min(self.gradient_clipping, torch.finfo(x.dtype).max) x = self.feed_forward1(x) residual = x - x = torch.clamp(x, -gc, gc) x = self.norm_pre_attn(x) x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask) - x = torch.clamp(x, -gc, gc) x = self.norm_post_attn(x) x = x + residual x = self.lconv1d(x) x = self.feed_forward2(x) - x = torch.clamp(x, -gc, gc) x = self.norm_out(x) return x class Gemma4AudioEncoder(nn.Module): - def __init__(self, config, dtype=None, device=None, operations=None): + def __init__(self, config, dtype=None, device=None, ops=None): super().__init__() self.hidden_size = config["hidden_size"] self.output_proj_dims = config.get("output_proj_dims", 1536) - self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, operations=operations) + self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops) self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype) self.layers = nn.ModuleList([ - Gemma4AudioLayer(config, device=device, dtype=dtype, operations=operations) + Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops) for _ in range(config["num_hidden_layers"]) ]) - self.output_proj = operations.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) + self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) def forward(self, audio_features, audio_mask=None): x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask) @@ -1054,8 +1008,8 @@ class Gemma4AudioEncoder(nn.Module): class Gemma4AudioProjector(Gemma4RMSNormProjector): - def __init__(self, config, dtype=None, device=None, operations=None): - super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, operations=operations) + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops) # Tokenizer and Wrappers @@ -1131,8 +1085,8 @@ class Gemma4_Tokenizer(): # Process audio audio_features = [] if audio is not None: - waveform = audio["waveform"].squeeze(0) if isinstance(audio, dict) else audio - sample_rate = audio.get("sample_rate", 16000) if isinstance(audio, dict) else 16000 + waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio + sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000 mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) @@ -1142,6 +1096,18 @@ class Gemma4_Tokenizer(): images = [] if source is not None: samples = source.movedim(-1, 1) # [B, C, H, W] + num_frames = samples.shape[0] + + # Subsample video to 1fps + if is_video: + fps = kwargs.get("fps", 24) + step = max(1, round(fps)) + indices = list(range(0, num_frames, step)) + if len(indices) == 0: + indices = [0] + samples = samples[indices] + num_frames = len(indices) + h, w = samples.shape[2], samples.shape[3] patch_size = 16 pooling_k = 3 @@ -1154,8 +1120,8 @@ class Gemma4_Tokenizer(): target_w = max(int(factor * w // side_mult) * side_mult, side_mult) import torchvision.transforms.functional as TVF - for i in range(samples.shape[0]): - # recaling to match reference code + for i in range(num_frames): + # rescaling to match reference code s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 if target_h != h or target_w != w: s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True) @@ -1176,11 +1142,9 @@ class Gemma4_Tokenizer(): media = "" if len(images) > 0: if is_video: - fps = kwargs.get("fps", 24) media += "\n\n" for i in range(len(images)): - seconds = i / fps - ts = f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" + ts = f"{int(i // 60):02d}:{int(i % 60):02d}" sep = "" if i == 0 else " " media += f"{sep}{ts} <|image><|video|>" media += "\n\n" @@ -1221,16 +1185,10 @@ class Gemma4_Tokenizer(): i += 1 if len(images) > 0: - if is_video: - # Video: batch all frames into one embed dict, each placeholder gets its frame's tokens - all_pixels = torch.cat([img["pixels"] for img in images], dim=0) # [N, H, W, C] - img_embeds = [{"type": "video", "data": all_pixels, "max_soft_tokens": images[0]["max_soft_tokens"], "frame_idx": i} for i in range(len(images))] - for r in text_tokens: - _replace_placeholders(r, 258884, img_embeds) - else: - img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] - for r in text_tokens: - _replace_placeholders(r, 258880, img_embeds) + img_token_id = 258884 if is_video else 258880 + img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] + for r in text_tokens: + _replace_placeholders(r, img_token_id, img_embeds) if len(audio_features) > 0: aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features]