Cleanup, video fixes

This commit is contained in:
kijai 2026-04-07 12:37:29 +03:00
parent 93e8635110
commit 05eaceafa1

View File

@ -7,12 +7,13 @@ import math
from comfy import sd1_clip
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.rmsnorm import rms_norm
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook
GEMMA4_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "model_type": "gemma4_vision", "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_VISION_31B_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "gemma4_vision", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5, "gradient_clipping": 1e10, "hidden_act": "silu"}
GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
@dataclass
class Gemma4Config:
@ -74,6 +75,7 @@ class Gemma4_31B_Config(Gemma4Config):
vision_config = GEMMA4_VISION_31B_CONFIG
# unfused RoPE as addcmul_ RoPE diverges from reference code
def _apply_rotary_pos_emb(x, freqs_cis):
cos, sin = freqs_cis[0], freqs_cis[1]
half = x.shape[-1] // 2
@ -82,7 +84,6 @@ def _apply_rotary_pos_emb(x, freqs_cis):
out[..., half:] += x[..., :half] * sin[..., half:]
return out
def _apply_rope_gemma(xq, xk, freqs_cis):
    """Apply rotary position embeddings to both query and key tensors.

    Delegates to _apply_rotary_pos_emb for each tensor with the same
    precomputed (cos, sin) pair.
    """
    rotated_q = _apply_rotary_pos_emb(xq, freqs_cis)
    rotated_k = _apply_rotary_pos_emb(xk, freqs_cis)
    return rotated_q, rotated_k
@ -117,7 +118,6 @@ class Gemma4Attention(nn.Module):
past_key_value=None,
sliding_window=None,
shared_kv=None,
**kwargs,
):
batch_size, seq_length, _ = hidden_states.shape
@ -137,7 +137,7 @@ class Gemma4Attention(nn.Module):
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
if self.k_norm is not None:
xk = self.k_norm(xk)
xv = _parameterless_rms_norm(xv)
xv = rms_norm(xv, fused=False)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
xq, xk = _apply_rope_gemma(xq, xk, freqs_cis=freqs_cis)
@ -298,22 +298,17 @@ class Gemma4Transformer(nn.Module):
return kv[2]
return 0
def _freqs_from_inv(self, inv_freq, position_ids, dtype=None):
def _freqs_from_inv(self, inv_freq, position_ids, device, dtype):
"""Compute cos/sin from stored inv_freq"""
inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(position_ids.device)
inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device)
pos_exp = position_ids[:, None, :].float()
freqs = (inv_exp @ pos_exp).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().unsqueeze(1)
sin = emb.sin().unsqueeze(1)
result = (cos, sin)
if dtype is not None:
result = tuple(t.to(dtype) for t in result)
return result
return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype)
def compute_freqs_cis(self, position_ids, device, dtype=None):
global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, dtype)
sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, dtype)
global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype)
sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype)
return [global_freqs, sliding_freqs]
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None,
@ -358,7 +353,7 @@ class Gemma4Transformer(nn.Module):
else:
per_layer_inputs = per_layer_proj
# KV sharing: only last sliding (22) and last global (23) layers store KV for sharing
# KV sharing: later layers reuse KV from the last non-shared sliding/global layer
num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0)
first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers
shared_sliding_kv = None # KV from last non-shared sliding layer
@ -407,8 +402,8 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module):
self.num_layers = config.num_hidden_layers
self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype, device, operations)
self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype, device, operations)
self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations)
self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations)
def logits(self, x):
logits = super().logits(x)
@ -425,39 +420,26 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module):
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = embed["data"].movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W]
image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W]
max_soft_tokens = embed.get("max_soft_tokens", None)
vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens)
return self.multi_modal_projector(vision_out), None
if embed["type"] == "video":
frame_idx = embed.get("frame_idx", 0)
if not hasattr(self, '_video_cache') or self._video_cache is None:
# First frame: process all frames as a batch
frames = embed["data"].movedim(-1, 1) # [N, H, W, C] -> [N, C, H, W]
max_soft_tokens = embed.get("max_soft_tokens", None)
vision_out = self.vision_model(frames.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens)
projected = self.multi_modal_projector(vision_out) # [N, tokens_per_frame, hidden]
self._video_cache = projected
result = self._video_cache[frame_idx:frame_idx+1] # [1, tokens_per_frame, hidden]
if frame_idx == self._video_cache.shape[0] - 1:
self._video_cache = None # clear after last frame
return result, None
return None, None
class Gemma4AudioMixin:
"""Adds audio support to a Gemma4 model."""
def _init_audio(self, config, dtype, device, operations):
self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype, device, operations)
self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype, device, operations)
self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations)
self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations)
def preprocess_embed(self, embed, device):
result, extra = super().preprocess_embed(embed, device)
if result is not None:
return result, extra
if embed["type"] == "audio":
audio = embed["data"].to(device, dtype=torch.float32)
audio_mask = embed.get("mask", None)
audio = embed.pop("data").to(device, dtype=torch.float32)
audio_mask = embed.pop("mask", None)
if audio_mask is not None:
audio_mask = audio_mask.to(device)
audio_out = self.audio_model(audio, audio_mask=audio_mask)
@ -474,12 +456,6 @@ class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base):
# Vision Encoder
def _parameterless_rms_norm(x, eps=1e-6):
"""RMSNorm without learnable weight (used by Gemma4 v_norm and projectors)."""
mean_squared = x.float().pow(2).mean(-1, keepdim=True) + eps
return (x.float() * torch.pow(mean_squared, -0.5)).to(x.dtype)
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
"""Compute 2D RoPE for vision: separate frequencies for x and y dimensions.
@ -507,16 +483,16 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No
return cos, sin
def _apply_vision_2d_rope(x, cos, sin):
def _apply_vision_2d_rope(x, freqs):
"""Apply 2D RoPE (multidimensional) to vision query/key states.
Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently.
x: [batch, heads, seq, head_dim]
cos, sin: [batch, seq, head_dim]
freqs: (cos, sin) each [batch, seq, head_dim]
"""
cos = cos.unsqueeze(1) # [batch, 1, seq, head_dim]
sin = sin.unsqueeze(1)
cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim]
sin = freqs[1].unsqueeze(1)
def rotate_half(t):
t1 = t[..., :t.shape[-1]//2]
@ -541,9 +517,8 @@ class ClippedLinear(nn.Module):
Stores input_max/min and output_max/min as buffers loaded from checkpoint.
"""
def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, operations=None):
def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None):
super().__init__()
ops = operations or nn
self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype))
self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype))
@ -557,59 +532,51 @@ class ClippedLinear(nn.Module):
def forward(self, x):
x = x.clamp(min=self.input_min, max=self.input_max)
x = self.linear(x)
x = x.clamp(min=self.output_min, max=self.output_max)
return x
return x.clamp_(min=self.output_min, max=self.output_max)
class Gemma4VisionMLP(nn.Module):
"""SwiGLU MLP matching gate_proj/up_proj/down_proj structure."""
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden_size = config["hidden_size"]
intermediate_size = config["intermediate_size"]
self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations)
self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
def forward(self, x):
return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))
class Gemma4VisionAttention(nn.Module):
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.hidden_size = config["hidden_size"]
self.num_heads = config["num_attention_heads"]
self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads)
self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, operations=operations)
self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
def forward(self, x, cos_sin=None, attention_mask=None, **kwargs):
def forward(self, x, freqs, attention_mask=None):
batch_size, seq_length, _ = x.shape
xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xv = _parameterless_rms_norm(xv)
xq = self.q_norm(xq).transpose(1, 2)
xk = self.k_norm(xk).transpose(1, 2)
xv = rms_norm(xv, fused=False)
xq = xq.transpose(1, 2) # [B, H, S, D]
xk = xk.transpose(1, 2)
# Apply 2D RoPE
if cos_sin is not None:
cos, sin = cos_sin
xq = _apply_vision_2d_rope(xq, cos, sin)
xk = _apply_vision_2d_rope(xk, cos, sin)
xq = _apply_vision_2d_rope(xq, freqs)
xk = _apply_vision_2d_rope(xk, freqs)
xv = xv.to(xq.dtype).transpose(1, 2)
@ -618,10 +585,10 @@ class Gemma4VisionAttention(nn.Module):
class Gemma4VisionLayer(nn.Module):
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, operations=operations)
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, operations=operations)
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
hidden = config["hidden_size"]
self.input_layernorm = RMSNorm(hidden, **norm_kwargs)
@ -629,10 +596,10 @@ class Gemma4VisionLayer(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs)
self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs)
def forward(self, x, cos_sin=None, attention_mask=None):
def forward(self, x, freqs, attention_mask=None):
residual = x
x = self.input_layernorm(x)
x = self.self_attn(x, cos_sin=cos_sin, attention_mask=attention_mask)
x = self.self_attn(x, freqs, attention_mask=attention_mask)
x = self.post_attention_layernorm(x)
x = residual + x
@ -646,14 +613,14 @@ class Gemma4VisionLayer(nn.Module):
class Gemma4PatchEmbedder(nn.Module):
"""Patch embedding with learned 2D position embeddings via one-hot lookup."""
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden_size = config["hidden_size"]
patch_size = config["patch_size"]
self.patch_size = patch_size
self.position_embedding_size = config.get("position_embedding_size", 10240)
self.input_proj = operations.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype)
self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype)
self.position_embedding_table = nn.Parameter(
torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype)
)
@ -680,16 +647,16 @@ class Gemma4PatchEmbedder(nn.Module):
class Gemma4VisionEncoderLayers(nn.Module):
"""Wrapper to produce state dict keys as encoder.layers.X.*"""
def __init__(self, config, dtype=None, device=None, operations=None):
def __init__(self, config, dtype=None, device=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([
Gemma4VisionLayer(config, device=device, dtype=dtype, operations=operations)
Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config["num_hidden_layers"])
])
class Gemma4VisionEncoder(nn.Module):
def __init__(self, config, dtype=None, device=None, operations=None):
def __init__(self, config, dtype=None, device=None, ops=None):
super().__init__()
self.config = config
self.hidden_size = config["hidden_size"]
@ -698,8 +665,8 @@ class Gemma4VisionEncoder(nn.Module):
self.pooling_kernel_size = config.get("pooling_kernel_size", 3)
self.root_hidden_size = self.hidden_size ** 0.5
self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, operations=operations)
self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, operations=operations)
self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops)
self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops)
def forward(self, pixel_values, max_soft_tokens=None):
"""
@ -720,7 +687,7 @@ class Gemma4VisionEncoder(nn.Module):
grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij')
position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1)
# Append zero-pixel padding with (-1,-1) positions (matching HF)
# Append zero-pixel padding with (-1,-1) positions
if n_padding > 0:
patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1)
position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1)
@ -729,12 +696,12 @@ class Gemma4VisionEncoder(nn.Module):
# Embed, encode, pool
x = self.patch_embedder(patches, position_ids)
cos_sin = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
cos_sin = tuple(t.to(x.dtype) for t in cos_sin)
freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
freqs = tuple(t.to(x.dtype) for t in freqs)
mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None
for layer in self.encoder.layers:
x = layer(x, cos_sin=cos_sin, attention_mask=mask)
x = layer(x, freqs, attention_mask=mask)
if n_padding > 0:
x = x.masked_fill(padding.unsqueeze(-1), 0.0)
@ -757,36 +724,36 @@ class Gemma4VisionEncoder(nn.Module):
class Gemma4RMSNormProjector(nn.Module):
"""Shared projector: parameterless RMSNorm → linear. Used for both vision and audio."""
def __init__(self, in_dim, out_dim, dtype=None, device=None, operations=None):
def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None):
super().__init__()
self.embedding_projection = operations.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.embedding_projection(_parameterless_rms_norm(x))
return self.embedding_projection(rms_norm(x, fused=False))
class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
def __init__(self, config, dtype=None, device=None, operations=None):
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, operations=operations)
def __init__(self, config, dtype=None, device=None, ops=None):
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
# Audio Encoder
class Gemma4AudioConvSubsampler(nn.Module):
"""2D convolution subsampling for audio features"""
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
eps = config.get("rms_norm_eps", 1e-6)
self.layer0 = nn.ModuleDict({
'conv': operations.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
'norm': operations.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
})
self.layer1 = nn.ModuleDict({
'conv': operations.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
'norm': operations.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
})
# proj_input_dim = (128 // 4) * 32 = 1024
self.input_proj_linear = operations.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype)
self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype)
def _conv_layer(self, x, layer, mask):
if mask is not None:
@ -807,26 +774,22 @@ class Gemma4AudioConvSubsampler(nn.Module):
class Gemma4AudioFeedForward(nn.Module):
"""Conformer feed-forward with gradient clipping and residual scaling."""
def __init__(self, config, device=None, dtype=None, operations=None):
"""Conformer feed-forward with residual scaling."""
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden_size = config["hidden_size"]
intermediate_size = config.get("intermediate_size", hidden_size * 4)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations)
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.post_layer_scale = config.get("residual_weight", 0.5)
self.gradient_clipping = config.get("gradient_clipping", 1e10)
def forward(self, x):
residual = x
gc = min(self.gradient_clipping, torch.finfo(x.dtype).max)
x = torch.clamp(x, -gc, gc)
x = self.pre_layer_norm(x)
x = torch.nn.functional.silu(self.ffw_layer_1(x))
x = self.ffw_layer_2(x)
x = torch.clamp(x, -gc, gc)
x = self.post_layer_norm(x)
x = x * self.post_layer_scale
return x + residual
@ -855,7 +818,7 @@ class Gemma4AudioRelPositionalEncoding(nn.Module):
class Gemma4AudioAttention(nn.Module):
"""Chunked block attention with relative position bias and softcap."""
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.hidden_size = config["hidden_size"]
self.num_heads = config["num_attention_heads"]
@ -869,12 +832,12 @@ class Gemma4AudioAttention(nn.Module):
self.k_scale = math.log(1 + math.e) / math.log(2)
self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False)
self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype))
self.relative_k_proj = operations.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype)
self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype)
def _convert_to_block(self, x):
B, S, H, D = x.shape
@ -884,7 +847,6 @@ class Gemma4AudioAttention(nn.Module):
return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous()
def _extract_block_context(self, x):
B, S, H, D = x.shape
x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1))
x = x.unfold(1, self.context_size, self.chunk_size)
return torch.movedim(x, -1, 2).contiguous()
@ -907,7 +869,7 @@ class Gemma4AudioAttention(nn.Module):
if audio_mask is not None:
mask = mask & audio_mask[0, None, :].bool()
m = mask[None, None]
# Reshape to blocked 5D matching reference's _convert_4d_mask_to_blocked_5d
# Reshape to blocked 5D matching reference code
p = num_blocks * self.chunk_size - seq_len
m = torch.nn.functional.pad(m, (0, p, 0, p), value=False)
m = m.reshape(1, 1, num_blocks, self.chunk_size, -1)
@ -957,18 +919,17 @@ class Gemma4AudioAttention(nn.Module):
class Gemma4AudioLConv1d(nn.Module):
"""Lightweight convolution with standard GLU."""
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden_size = config["hidden_size"]
conv_kernel_size = config.get("conv_kernel_size", 5)
self.gradient_clipping = config.get("gradient_clipping", 1e10)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, operations=operations)
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops)
# Causal conv: left-pad only
self.depthwise_conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1
self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, operations=operations)
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops)
def forward(self, x):
residual = x
@ -978,8 +939,6 @@ class Gemma4AudioLConv1d(nn.Module):
x = x.transpose(1, 2)
x = torch.nn.functional.pad(x, (self.conv_left_pad, 0))
x = self.depthwise_conv1d(x).transpose(1, 2)
gc = min(self.gradient_clipping, torch.finfo(x.dtype).max)
x = torch.clamp(x, -gc, gc)
x = self.conv_norm(x)
x = torch.nn.functional.silu(x)
x = self.linear_end(x)
@ -988,54 +947,49 @@ class Gemma4AudioLConv1d(nn.Module):
class Gemma4AudioLayer(nn.Module):
"""Conformer block: FFN1 -> Attention -> LConv -> FFN2."""
def __init__(self, config, device=None, dtype=None, operations=None):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.gradient_clipping = config.get("gradient_clipping", 1e10)
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations)
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, operations=operations)
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
hidden_size = config["hidden_size"]
self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs)
self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs)
self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, operations=operations)
self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations)
self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops)
self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
self.norm_out = RMSNorm(hidden_size, **norm_kwargs)
def forward(self, x, position_embeddings=None, attn_mask=None):
gc = min(self.gradient_clipping, torch.finfo(x.dtype).max)
x = self.feed_forward1(x)
residual = x
x = torch.clamp(x, -gc, gc)
x = self.norm_pre_attn(x)
x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask)
x = torch.clamp(x, -gc, gc)
x = self.norm_post_attn(x)
x = x + residual
x = self.lconv1d(x)
x = self.feed_forward2(x)
x = torch.clamp(x, -gc, gc)
x = self.norm_out(x)
return x
class Gemma4AudioEncoder(nn.Module):
def __init__(self, config, dtype=None, device=None, operations=None):
def __init__(self, config, dtype=None, device=None, ops=None):
super().__init__()
self.hidden_size = config["hidden_size"]
self.output_proj_dims = config.get("output_proj_dims", 1536)
self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, operations=operations)
self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops)
self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Gemma4AudioLayer(config, device=device, dtype=dtype, operations=operations)
Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config["num_hidden_layers"])
])
self.output_proj = operations.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype)
self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype)
def forward(self, audio_features, audio_mask=None):
x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask)
@ -1054,8 +1008,8 @@ class Gemma4AudioEncoder(nn.Module):
class Gemma4AudioProjector(Gemma4RMSNormProjector):
def __init__(self, config, dtype=None, device=None, operations=None):
super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, operations=operations)
def __init__(self, config, dtype=None, device=None, ops=None):
super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops)
# Tokenizer and Wrappers
@ -1131,8 +1085,8 @@ class Gemma4_Tokenizer():
# Process audio
audio_features = []
if audio is not None:
waveform = audio["waveform"].squeeze(0) if isinstance(audio, dict) else audio
sample_rate = audio.get("sample_rate", 16000) if isinstance(audio, dict) else 16000
waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T])
@ -1142,6 +1096,18 @@ class Gemma4_Tokenizer():
images = []
if source is not None:
samples = source.movedim(-1, 1) # [B, C, H, W]
num_frames = samples.shape[0]
# Subsample video to 1fps
if is_video:
fps = kwargs.get("fps", 24)
step = max(1, round(fps))
indices = list(range(0, num_frames, step))
if len(indices) == 0:
indices = [0]
samples = samples[indices]
num_frames = len(indices)
h, w = samples.shape[2], samples.shape[3]
patch_size = 16
pooling_k = 3
@ -1154,8 +1120,8 @@ class Gemma4_Tokenizer():
target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
import torchvision.transforms.functional as TVF
for i in range(samples.shape[0]):
# recaling to match reference code
for i in range(num_frames):
# rescaling to match reference code
s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
if target_h != h or target_w != w:
s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True)
@ -1176,11 +1142,9 @@ class Gemma4_Tokenizer():
media = ""
if len(images) > 0:
if is_video:
fps = kwargs.get("fps", 24)
media += "\n\n"
for i in range(len(images)):
seconds = i / fps
ts = f"{int(seconds // 60):02d}:{int(seconds % 60):02d}"
ts = f"{int(i // 60):02d}:{int(i % 60):02d}"
sep = "" if i == 0 else " "
media += f"{sep}{ts} <|image><|video|><image|>"
media += "\n\n"
@ -1221,16 +1185,10 @@ class Gemma4_Tokenizer():
i += 1
if len(images) > 0:
if is_video:
# Video: batch all frames into one embed dict, each placeholder gets its frame's tokens
all_pixels = torch.cat([img["pixels"] for img in images], dim=0) # [N, H, W, C]
img_embeds = [{"type": "video", "data": all_pixels, "max_soft_tokens": images[0]["max_soft_tokens"], "frame_idx": i} for i in range(len(images))]
for r in text_tokens:
_replace_placeholders(r, 258884, img_embeds)
else:
img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images]
for r in text_tokens:
_replace_placeholders(r, 258880, img_embeds)
img_token_id = 258884 if is_video else 258880
img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images]
for r in text_tokens:
_replace_placeholders(r, img_token_id, img_embeds)
if len(audio_features) > 0:
aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features]