mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Cleanup, video fixes
This commit is contained in:
parent
93e8635110
commit
05eaceafa1
@ -7,12 +7,13 @@ import math
|
||||
from comfy import sd1_clip
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.rmsnorm import rms_norm
|
||||
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook
|
||||
|
||||
|
||||
GEMMA4_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "model_type": "gemma4_vision", "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
GEMMA4_VISION_31B_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "gemma4_vision", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5, "gradient_clipping": 1e10, "hidden_act": "silu"}
|
||||
GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
|
||||
|
||||
@dataclass
|
||||
class Gemma4Config:
|
||||
@ -74,6 +75,7 @@ class Gemma4_31B_Config(Gemma4Config):
|
||||
vision_config = GEMMA4_VISION_31B_CONFIG
|
||||
|
||||
|
||||
# unfused RoPE as addcmul_ RoPE diverges from reference code
|
||||
def _apply_rotary_pos_emb(x, freqs_cis):
|
||||
cos, sin = freqs_cis[0], freqs_cis[1]
|
||||
half = x.shape[-1] // 2
|
||||
@ -82,7 +84,6 @@ def _apply_rotary_pos_emb(x, freqs_cis):
|
||||
out[..., half:] += x[..., :half] * sin[..., half:]
|
||||
return out
|
||||
|
||||
|
||||
def _apply_rope_gemma(xq, xk, freqs_cis):
|
||||
return _apply_rotary_pos_emb(xq, freqs_cis), _apply_rotary_pos_emb(xk, freqs_cis)
|
||||
|
||||
@ -117,7 +118,6 @@ class Gemma4Attention(nn.Module):
|
||||
past_key_value=None,
|
||||
sliding_window=None,
|
||||
shared_kv=None,
|
||||
**kwargs,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
@ -137,7 +137,7 @@ class Gemma4Attention(nn.Module):
|
||||
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
||||
if self.k_norm is not None:
|
||||
xk = self.k_norm(xk)
|
||||
xv = _parameterless_rms_norm(xv)
|
||||
xv = rms_norm(xv, fused=False)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
xq, xk = _apply_rope_gemma(xq, xk, freqs_cis=freqs_cis)
|
||||
@ -298,22 +298,17 @@ class Gemma4Transformer(nn.Module):
|
||||
return kv[2]
|
||||
return 0
|
||||
|
||||
def _freqs_from_inv(self, inv_freq, position_ids, dtype=None):
|
||||
def _freqs_from_inv(self, inv_freq, position_ids, device, dtype):
|
||||
"""Compute cos/sin from stored inv_freq"""
|
||||
inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(position_ids.device)
|
||||
inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device)
|
||||
pos_exp = position_ids[:, None, :].float()
|
||||
freqs = (inv_exp @ pos_exp).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().unsqueeze(1)
|
||||
sin = emb.sin().unsqueeze(1)
|
||||
result = (cos, sin)
|
||||
if dtype is not None:
|
||||
result = tuple(t.to(dtype) for t in result)
|
||||
return result
|
||||
return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype)
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device, dtype=None):
|
||||
global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, dtype)
|
||||
sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, dtype)
|
||||
global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype)
|
||||
sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype)
|
||||
return [global_freqs, sliding_freqs]
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None,
|
||||
@ -358,7 +353,7 @@ class Gemma4Transformer(nn.Module):
|
||||
else:
|
||||
per_layer_inputs = per_layer_proj
|
||||
|
||||
# KV sharing: only last sliding (22) and last global (23) layers store KV for sharing
|
||||
# KV sharing: later layers reuse KV from the last non-shared sliding/global layer
|
||||
num_kv_shared = getattr(self.config, 'num_kv_shared_layers', 0)
|
||||
first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers
|
||||
shared_sliding_kv = None # KV from last non-shared sliding layer
|
||||
@ -407,8 +402,8 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype, device, operations)
|
||||
self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype, device, operations)
|
||||
self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations)
|
||||
self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations)
|
||||
|
||||
def logits(self, x):
|
||||
logits = super().logits(x)
|
||||
@ -425,39 +420,26 @@ class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image = embed["data"].movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W]
|
||||
image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W]
|
||||
max_soft_tokens = embed.get("max_soft_tokens", None)
|
||||
vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens)
|
||||
return self.multi_modal_projector(vision_out), None
|
||||
if embed["type"] == "video":
|
||||
frame_idx = embed.get("frame_idx", 0)
|
||||
if not hasattr(self, '_video_cache') or self._video_cache is None:
|
||||
# First frame: process all frames as a batch
|
||||
frames = embed["data"].movedim(-1, 1) # [N, H, W, C] -> [N, C, H, W]
|
||||
max_soft_tokens = embed.get("max_soft_tokens", None)
|
||||
vision_out = self.vision_model(frames.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens)
|
||||
projected = self.multi_modal_projector(vision_out) # [N, tokens_per_frame, hidden]
|
||||
self._video_cache = projected
|
||||
result = self._video_cache[frame_idx:frame_idx+1] # [1, tokens_per_frame, hidden]
|
||||
if frame_idx == self._video_cache.shape[0] - 1:
|
||||
self._video_cache = None # clear after last frame
|
||||
return result, None
|
||||
return None, None
|
||||
|
||||
|
||||
class Gemma4AudioMixin:
|
||||
"""Adds audio support to a Gemma4 model."""
|
||||
def _init_audio(self, config, dtype, device, operations):
|
||||
self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype, device, operations)
|
||||
self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype, device, operations)
|
||||
self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations)
|
||||
self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations)
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
result, extra = super().preprocess_embed(embed, device)
|
||||
if result is not None:
|
||||
return result, extra
|
||||
if embed["type"] == "audio":
|
||||
audio = embed["data"].to(device, dtype=torch.float32)
|
||||
audio_mask = embed.get("mask", None)
|
||||
audio = embed.pop("data").to(device, dtype=torch.float32)
|
||||
audio_mask = embed.pop("mask", None)
|
||||
if audio_mask is not None:
|
||||
audio_mask = audio_mask.to(device)
|
||||
audio_out = self.audio_model(audio, audio_mask=audio_mask)
|
||||
@ -474,12 +456,6 @@ class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base):
|
||||
|
||||
# Vision Encoder
|
||||
|
||||
def _parameterless_rms_norm(x, eps=1e-6):
|
||||
"""RMSNorm without learnable weight (used by Gemma4 v_norm and projectors)."""
|
||||
mean_squared = x.float().pow(2).mean(-1, keepdim=True) + eps
|
||||
return (x.float() * torch.pow(mean_squared, -0.5)).to(x.dtype)
|
||||
|
||||
|
||||
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
|
||||
"""Compute 2D RoPE for vision: separate frequencies for x and y dimensions.
|
||||
|
||||
@ -507,16 +483,16 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No
|
||||
return cos, sin
|
||||
|
||||
|
||||
def _apply_vision_2d_rope(x, cos, sin):
|
||||
def _apply_vision_2d_rope(x, freqs):
|
||||
"""Apply 2D RoPE (multidimensional) to vision query/key states.
|
||||
|
||||
Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently.
|
||||
|
||||
x: [batch, heads, seq, head_dim]
|
||||
cos, sin: [batch, seq, head_dim]
|
||||
freqs: (cos, sin) each [batch, seq, head_dim]
|
||||
"""
|
||||
cos = cos.unsqueeze(1) # [batch, 1, seq, head_dim]
|
||||
sin = sin.unsqueeze(1)
|
||||
cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim]
|
||||
sin = freqs[1].unsqueeze(1)
|
||||
|
||||
def rotate_half(t):
|
||||
t1 = t[..., :t.shape[-1]//2]
|
||||
@ -541,9 +517,8 @@ class ClippedLinear(nn.Module):
|
||||
|
||||
Stores input_max/min and output_max/min as buffers loaded from checkpoint.
|
||||
"""
|
||||
def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, operations=None):
|
||||
def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
ops = operations or nn
|
||||
self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||
self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype))
|
||||
self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype))
|
||||
@ -557,59 +532,51 @@ class ClippedLinear(nn.Module):
|
||||
def forward(self, x):
|
||||
x = x.clamp(min=self.input_min, max=self.input_max)
|
||||
x = self.linear(x)
|
||||
x = x.clamp(min=self.output_min, max=self.output_max)
|
||||
return x
|
||||
return x.clamp_(min=self.output_min, max=self.output_max)
|
||||
|
||||
|
||||
class Gemma4VisionMLP(nn.Module):
|
||||
"""SwiGLU MLP matching gate_proj/up_proj/down_proj structure."""
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
hidden_size = config["hidden_size"]
|
||||
intermediate_size = config["intermediate_size"]
|
||||
self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
|
||||
self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
|
||||
self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))
|
||||
|
||||
|
||||
class Gemma4VisionAttention(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.hidden_size = config["hidden_size"]
|
||||
self.num_heads = config["num_attention_heads"]
|
||||
self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads)
|
||||
|
||||
|
||||
self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
|
||||
self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
|
||||
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
|
||||
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
|
||||
def forward(self, x, cos_sin=None, attention_mask=None, **kwargs):
|
||||
def forward(self, x, freqs, attention_mask=None):
|
||||
batch_size, seq_length, _ = x.shape
|
||||
|
||||
xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
|
||||
xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
|
||||
xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
|
||||
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
xv = _parameterless_rms_norm(xv)
|
||||
xq = self.q_norm(xq).transpose(1, 2)
|
||||
xk = self.k_norm(xk).transpose(1, 2)
|
||||
xv = rms_norm(xv, fused=False)
|
||||
|
||||
xq = xq.transpose(1, 2) # [B, H, S, D]
|
||||
xk = xk.transpose(1, 2)
|
||||
|
||||
# Apply 2D RoPE
|
||||
if cos_sin is not None:
|
||||
cos, sin = cos_sin
|
||||
xq = _apply_vision_2d_rope(xq, cos, sin)
|
||||
xk = _apply_vision_2d_rope(xk, cos, sin)
|
||||
xq = _apply_vision_2d_rope(xq, freqs)
|
||||
xk = _apply_vision_2d_rope(xk, freqs)
|
||||
|
||||
xv = xv.to(xq.dtype).transpose(1, 2)
|
||||
|
||||
@ -618,10 +585,10 @@ class Gemma4VisionAttention(nn.Module):
|
||||
|
||||
|
||||
class Gemma4VisionLayer(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, operations=operations)
|
||||
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
|
||||
norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
hidden = config["hidden_size"]
|
||||
self.input_layernorm = RMSNorm(hidden, **norm_kwargs)
|
||||
@ -629,10 +596,10 @@ class Gemma4VisionLayer(nn.Module):
|
||||
self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs)
|
||||
self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs)
|
||||
|
||||
def forward(self, x, cos_sin=None, attention_mask=None):
|
||||
def forward(self, x, freqs, attention_mask=None):
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(x, cos_sin=cos_sin, attention_mask=attention_mask)
|
||||
x = self.self_attn(x, freqs, attention_mask=attention_mask)
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
@ -646,14 +613,14 @@ class Gemma4VisionLayer(nn.Module):
|
||||
|
||||
class Gemma4PatchEmbedder(nn.Module):
|
||||
"""Patch embedding with learned 2D position embeddings via one-hot lookup."""
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
hidden_size = config["hidden_size"]
|
||||
patch_size = config["patch_size"]
|
||||
self.patch_size = patch_size
|
||||
self.position_embedding_size = config.get("position_embedding_size", 10240)
|
||||
|
||||
self.input_proj = operations.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.position_embedding_table = nn.Parameter(
|
||||
torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype)
|
||||
)
|
||||
@ -680,16 +647,16 @@ class Gemma4PatchEmbedder(nn.Module):
|
||||
|
||||
class Gemma4VisionEncoderLayers(nn.Module):
|
||||
"""Wrapper to produce state dict keys as encoder.layers.X.*"""
|
||||
def __init__(self, config, dtype=None, device=None, operations=None):
|
||||
def __init__(self, config, dtype=None, device=None, ops=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
Gemma4VisionLayer(config, device=device, dtype=dtype, operations=operations)
|
||||
Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config["num_hidden_layers"])
|
||||
])
|
||||
|
||||
|
||||
class Gemma4VisionEncoder(nn.Module):
|
||||
def __init__(self, config, dtype=None, device=None, operations=None):
|
||||
def __init__(self, config, dtype=None, device=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config["hidden_size"]
|
||||
@ -698,8 +665,8 @@ class Gemma4VisionEncoder(nn.Module):
|
||||
self.pooling_kernel_size = config.get("pooling_kernel_size", 3)
|
||||
self.root_hidden_size = self.hidden_size ** 0.5
|
||||
|
||||
self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, operations=operations)
|
||||
self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, operations=operations)
|
||||
self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops)
|
||||
self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops)
|
||||
|
||||
def forward(self, pixel_values, max_soft_tokens=None):
|
||||
"""
|
||||
@ -720,7 +687,7 @@ class Gemma4VisionEncoder(nn.Module):
|
||||
grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij')
|
||||
position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
# Append zero-pixel padding with (-1,-1) positions (matching HF)
|
||||
# Append zero-pixel padding with (-1,-1) positions
|
||||
if n_padding > 0:
|
||||
patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1)
|
||||
position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1)
|
||||
@ -729,12 +696,12 @@ class Gemma4VisionEncoder(nn.Module):
|
||||
|
||||
# Embed, encode, pool
|
||||
x = self.patch_embedder(patches, position_ids)
|
||||
cos_sin = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
|
||||
cos_sin = tuple(t.to(x.dtype) for t in cos_sin)
|
||||
freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
|
||||
freqs = tuple(t.to(x.dtype) for t in freqs)
|
||||
mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None
|
||||
|
||||
for layer in self.encoder.layers:
|
||||
x = layer(x, cos_sin=cos_sin, attention_mask=mask)
|
||||
x = layer(x, freqs, attention_mask=mask)
|
||||
|
||||
if n_padding > 0:
|
||||
x = x.masked_fill(padding.unsqueeze(-1), 0.0)
|
||||
@ -757,36 +724,36 @@ class Gemma4VisionEncoder(nn.Module):
|
||||
|
||||
class Gemma4RMSNormProjector(nn.Module):
|
||||
"""Shared projector: parameterless RMSNorm → linear. Used for both vision and audio."""
|
||||
def __init__(self, in_dim, out_dim, dtype=None, device=None, operations=None):
|
||||
def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None):
|
||||
super().__init__()
|
||||
self.embedding_projection = operations.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
|
||||
self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embedding_projection(_parameterless_rms_norm(x))
|
||||
return self.embedding_projection(rms_norm(x, fused=False))
|
||||
|
||||
|
||||
class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
|
||||
def __init__(self, config, dtype=None, device=None, operations=None):
|
||||
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
def __init__(self, config, dtype=None, device=None, ops=None):
|
||||
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
|
||||
|
||||
|
||||
# Audio Encoder
|
||||
|
||||
class Gemma4AudioConvSubsampler(nn.Module):
|
||||
"""2D convolution subsampling for audio features"""
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
eps = config.get("rms_norm_eps", 1e-6)
|
||||
self.layer0 = nn.ModuleDict({
|
||||
'conv': operations.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
|
||||
'norm': operations.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
|
||||
'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
|
||||
'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
|
||||
})
|
||||
self.layer1 = nn.ModuleDict({
|
||||
'conv': operations.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
|
||||
'norm': operations.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
|
||||
'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
|
||||
'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
|
||||
})
|
||||
# proj_input_dim = (128 // 4) * 32 = 1024
|
||||
self.input_proj_linear = operations.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype)
|
||||
self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype)
|
||||
|
||||
def _conv_layer(self, x, layer, mask):
|
||||
if mask is not None:
|
||||
@ -807,26 +774,22 @@ class Gemma4AudioConvSubsampler(nn.Module):
|
||||
|
||||
|
||||
class Gemma4AudioFeedForward(nn.Module):
|
||||
"""Conformer feed-forward with gradient clipping and residual scaling."""
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
"""Conformer feed-forward with residual scaling."""
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
hidden_size = config["hidden_size"]
|
||||
intermediate_size = config.get("intermediate_size", hidden_size * 4)
|
||||
self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
|
||||
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
self.post_layer_scale = config.get("residual_weight", 0.5)
|
||||
self.gradient_clipping = config.get("gradient_clipping", 1e10)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
gc = min(self.gradient_clipping, torch.finfo(x.dtype).max)
|
||||
x = torch.clamp(x, -gc, gc)
|
||||
x = self.pre_layer_norm(x)
|
||||
x = torch.nn.functional.silu(self.ffw_layer_1(x))
|
||||
x = self.ffw_layer_2(x)
|
||||
x = torch.clamp(x, -gc, gc)
|
||||
x = self.post_layer_norm(x)
|
||||
x = x * self.post_layer_scale
|
||||
return x + residual
|
||||
@ -855,7 +818,7 @@ class Gemma4AudioRelPositionalEncoding(nn.Module):
|
||||
|
||||
class Gemma4AudioAttention(nn.Module):
|
||||
"""Chunked block attention with relative position bias and softcap."""
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.hidden_size = config["hidden_size"]
|
||||
self.num_heads = config["num_attention_heads"]
|
||||
@ -869,12 +832,12 @@ class Gemma4AudioAttention(nn.Module):
|
||||
self.k_scale = math.log(1 + math.e) / math.log(2)
|
||||
self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False)
|
||||
|
||||
self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype))
|
||||
self.relative_k_proj = operations.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def _convert_to_block(self, x):
|
||||
B, S, H, D = x.shape
|
||||
@ -884,7 +847,6 @@ class Gemma4AudioAttention(nn.Module):
|
||||
return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous()
|
||||
|
||||
def _extract_block_context(self, x):
|
||||
B, S, H, D = x.shape
|
||||
x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1))
|
||||
x = x.unfold(1, self.context_size, self.chunk_size)
|
||||
return torch.movedim(x, -1, 2).contiguous()
|
||||
@ -907,7 +869,7 @@ class Gemma4AudioAttention(nn.Module):
|
||||
if audio_mask is not None:
|
||||
mask = mask & audio_mask[0, None, :].bool()
|
||||
m = mask[None, None]
|
||||
# Reshape to blocked 5D matching reference's _convert_4d_mask_to_blocked_5d
|
||||
# Reshape to blocked 5D matching reference code
|
||||
p = num_blocks * self.chunk_size - seq_len
|
||||
m = torch.nn.functional.pad(m, (0, p, 0, p), value=False)
|
||||
m = m.reshape(1, 1, num_blocks, self.chunk_size, -1)
|
||||
@ -957,18 +919,17 @@ class Gemma4AudioAttention(nn.Module):
|
||||
|
||||
class Gemma4AudioLConv1d(nn.Module):
|
||||
"""Lightweight convolution with standard GLU."""
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
hidden_size = config["hidden_size"]
|
||||
conv_kernel_size = config.get("conv_kernel_size", 5)
|
||||
self.gradient_clipping = config.get("gradient_clipping", 1e10)
|
||||
self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, operations=operations)
|
||||
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops)
|
||||
# Causal conv: left-pad only
|
||||
self.depthwise_conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1
|
||||
self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
@ -978,8 +939,6 @@ class Gemma4AudioLConv1d(nn.Module):
|
||||
x = x.transpose(1, 2)
|
||||
x = torch.nn.functional.pad(x, (self.conv_left_pad, 0))
|
||||
x = self.depthwise_conv1d(x).transpose(1, 2)
|
||||
gc = min(self.gradient_clipping, torch.finfo(x.dtype).max)
|
||||
x = torch.clamp(x, -gc, gc)
|
||||
x = self.conv_norm(x)
|
||||
x = torch.nn.functional.silu(x)
|
||||
x = self.linear_end(x)
|
||||
@ -988,54 +947,49 @@ class Gemma4AudioLConv1d(nn.Module):
|
||||
|
||||
class Gemma4AudioLayer(nn.Module):
|
||||
"""Conformer block: FFN1 -> Attention -> LConv -> FFN2."""
|
||||
def __init__(self, config, device=None, dtype=None, operations=None):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.gradient_clipping = config.get("gradient_clipping", 1e10)
|
||||
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations)
|
||||
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, operations=operations)
|
||||
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
|
||||
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
|
||||
hidden_size = config["hidden_size"]
|
||||
self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs)
|
||||
self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs)
|
||||
self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, operations=operations)
|
||||
self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, operations=operations)
|
||||
self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops)
|
||||
self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
|
||||
self.norm_out = RMSNorm(hidden_size, **norm_kwargs)
|
||||
|
||||
def forward(self, x, position_embeddings=None, attn_mask=None):
|
||||
gc = min(self.gradient_clipping, torch.finfo(x.dtype).max)
|
||||
x = self.feed_forward1(x)
|
||||
|
||||
residual = x
|
||||
x = torch.clamp(x, -gc, gc)
|
||||
x = self.norm_pre_attn(x)
|
||||
x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask)
|
||||
x = torch.clamp(x, -gc, gc)
|
||||
x = self.norm_post_attn(x)
|
||||
x = x + residual
|
||||
|
||||
x = self.lconv1d(x)
|
||||
x = self.feed_forward2(x)
|
||||
|
||||
x = torch.clamp(x, -gc, gc)
|
||||
x = self.norm_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Gemma4AudioEncoder(nn.Module):
|
||||
def __init__(self, config, dtype=None, device=None, operations=None):
|
||||
def __init__(self, config, dtype=None, device=None, ops=None):
|
||||
super().__init__()
|
||||
self.hidden_size = config["hidden_size"]
|
||||
self.output_proj_dims = config.get("output_proj_dims", 1536)
|
||||
|
||||
self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, operations=operations)
|
||||
self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops)
|
||||
self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
Gemma4AudioLayer(config, device=device, dtype=dtype, operations=operations)
|
||||
Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config["num_hidden_layers"])
|
||||
])
|
||||
|
||||
self.output_proj = operations.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype)
|
||||
self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, audio_features, audio_mask=None):
|
||||
x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask)
|
||||
@ -1054,8 +1008,8 @@ class Gemma4AudioEncoder(nn.Module):
|
||||
|
||||
|
||||
class Gemma4AudioProjector(Gemma4RMSNormProjector):
|
||||
def __init__(self, config, dtype=None, device=None, operations=None):
|
||||
super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, operations=operations)
|
||||
def __init__(self, config, dtype=None, device=None, ops=None):
|
||||
super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops)
|
||||
|
||||
|
||||
# Tokenizer and Wrappers
|
||||
@ -1131,8 +1085,8 @@ class Gemma4_Tokenizer():
|
||||
# Process audio
|
||||
audio_features = []
|
||||
if audio is not None:
|
||||
waveform = audio["waveform"].squeeze(0) if isinstance(audio, dict) else audio
|
||||
sample_rate = audio.get("sample_rate", 16000) if isinstance(audio, dict) else 16000
|
||||
waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
|
||||
sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
|
||||
mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
|
||||
audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T])
|
||||
|
||||
@ -1142,6 +1096,18 @@ class Gemma4_Tokenizer():
|
||||
images = []
|
||||
if source is not None:
|
||||
samples = source.movedim(-1, 1) # [B, C, H, W]
|
||||
num_frames = samples.shape[0]
|
||||
|
||||
# Subsample video to 1fps
|
||||
if is_video:
|
||||
fps = kwargs.get("fps", 24)
|
||||
step = max(1, round(fps))
|
||||
indices = list(range(0, num_frames, step))
|
||||
if len(indices) == 0:
|
||||
indices = [0]
|
||||
samples = samples[indices]
|
||||
num_frames = len(indices)
|
||||
|
||||
h, w = samples.shape[2], samples.shape[3]
|
||||
patch_size = 16
|
||||
pooling_k = 3
|
||||
@ -1154,8 +1120,8 @@ class Gemma4_Tokenizer():
|
||||
target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
|
||||
|
||||
import torchvision.transforms.functional as TVF
|
||||
for i in range(samples.shape[0]):
|
||||
# recaling to match reference code
|
||||
for i in range(num_frames):
|
||||
# rescaling to match reference code
|
||||
s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
|
||||
if target_h != h or target_w != w:
|
||||
s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True)
|
||||
@ -1176,11 +1142,9 @@ class Gemma4_Tokenizer():
|
||||
media = ""
|
||||
if len(images) > 0:
|
||||
if is_video:
|
||||
fps = kwargs.get("fps", 24)
|
||||
media += "\n\n"
|
||||
for i in range(len(images)):
|
||||
seconds = i / fps
|
||||
ts = f"{int(seconds // 60):02d}:{int(seconds % 60):02d}"
|
||||
ts = f"{int(i // 60):02d}:{int(i % 60):02d}"
|
||||
sep = "" if i == 0 else " "
|
||||
media += f"{sep}{ts} <|image><|video|><image|>"
|
||||
media += "\n\n"
|
||||
@ -1221,16 +1185,10 @@ class Gemma4_Tokenizer():
|
||||
i += 1
|
||||
|
||||
if len(images) > 0:
|
||||
if is_video:
|
||||
# Video: batch all frames into one embed dict, each placeholder gets its frame's tokens
|
||||
all_pixels = torch.cat([img["pixels"] for img in images], dim=0) # [N, H, W, C]
|
||||
img_embeds = [{"type": "video", "data": all_pixels, "max_soft_tokens": images[0]["max_soft_tokens"], "frame_idx": i} for i in range(len(images))]
|
||||
for r in text_tokens:
|
||||
_replace_placeholders(r, 258884, img_embeds)
|
||||
else:
|
||||
img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images]
|
||||
for r in text_tokens:
|
||||
_replace_placeholders(r, 258880, img_embeds)
|
||||
img_token_id = 258884 if is_video else 258880
|
||||
img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images]
|
||||
for r in text_tokens:
|
||||
_replace_placeholders(r, img_token_id, img_embeds)
|
||||
|
||||
if len(audio_features) > 0:
|
||||
aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user