Support Gemma4 12B

This commit is contained in:
kijai 2026-06-05 20:53:38 +03:00
parent 5aa71b9bc2
commit 22f6e40732
3 changed files with 218 additions and 41 deletions

View File

@ -1353,6 +1353,7 @@ class TEModel(Enum):
GEMMA_4_31B = 31 GEMMA_4_31B = 31
T5_GEMMA = 32 T5_GEMMA = 32
GPT_OSS_20B = 33 GPT_OSS_20B = 33
GEMMA_4_12B = 34
def detect_te_model(sd): def detect_te_model(sd):
@ -1382,6 +1383,9 @@ def detect_te_model(sd):
if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd: if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B return TEModel.GEMMA_4_31B
# Gemma4 12B Unified: 48 layers, encoder-free; global layers drop v_proj (attention_k_eq_v).
if 'model.layers.47.self_attn.q_norm.weight' in sd and 'model.layers.5.self_attn.v_proj.weight' not in sd:
return TEModel.GEMMA_4_12B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd: if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd: if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
@ -1535,10 +1539,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel
clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B, TEModel.GEMMA_4_12B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model] TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B,
TEModel.GEMMA_4_12B: comfy.text_encoders.gemma4.Gemma4_12B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant) clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)

View File

@ -1,6 +1,9 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio.functional as AF
import torchvision.transforms.functional as TVF
import numpy as np import numpy as np
from tokenizers import Tokenizer
from dataclasses import dataclass from dataclasses import dataclass
import math import math
@ -21,6 +24,10 @@ GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_siz
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
# Encoder-free (gemma4_unified) multimodal embedders: raw patches/waveform projected directly into LM space.
GEMMA4_UNIFIED_VISION_CONFIG = {"model_patch_size": 48, "patch_size": 16, "pooling_kernel_size": 3, "mm_embed_dim": 3840, "mm_posemb_size": 1120, "output_proj_dims": 3840, "rms_norm_eps": 1e-6}
GEMMA4_UNIFIED_AUDIO_CONFIG = {"audio_samples_per_token": 640, "output_proj_dims": 640, "rms_norm_eps": 1e-6}
@dataclass @dataclass
class Gemma4Config: class Gemma4Config:
vocab_size: int = 262144 vocab_size: int = 262144
@ -35,6 +42,9 @@ class Gemma4Config:
transformer_type: str = "gemma4" transformer_type: str = "gemma4"
head_dim = 256 head_dim = 256
global_head_dim = 512 global_head_dim = 512
num_global_key_value_heads = None
attention_k_eq_v = False
vision_bidirectional = False
rms_norm_add = False rms_norm_add = False
mlp_activation = "gelu_pytorch_tanh" mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False qkv_bias = False
@ -72,12 +82,29 @@ class Gemma4_31B_Config(Gemma4Config):
num_hidden_layers: int = 60 num_hidden_layers: int = 60
num_attention_heads: int = 32 num_attention_heads: int = 32
num_key_value_heads: int = 16 num_key_value_heads: int = 16
vision_bidirectional = True
sliding_attention = [1024, 1024, 1024, 1024, 1024, False] sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
hidden_size_per_layer_input: int = 0 hidden_size_per_layer_input: int = 0
num_kv_shared_layers: int = 0 num_kv_shared_layers: int = 0
audio_config = None audio_config = None
vision_config = GEMMA4_VISION_31B_CONFIG vision_config = GEMMA4_VISION_31B_CONFIG
@dataclass
class Gemma4_12B_Config(Gemma4Config):
hidden_size: int = 3840
intermediate_size: int = 15360
num_hidden_layers: int = 48
num_attention_heads: int = 16
num_key_value_heads: int = 8
num_global_key_value_heads = 1
attention_k_eq_v = True
vision_bidirectional = True
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
hidden_size_per_layer_input: int = 0
num_kv_shared_layers: int = 0
audio_config = GEMMA4_UNIFIED_AUDIO_CONFIG
vision_config = GEMMA4_UNIFIED_VISION_CONFIG
# unfused RoPE as addcmul_ RoPE diverges from reference code # unfused RoPE as addcmul_ RoPE diverges from reference code
def _apply_rotary_pos_emb(x, freqs_cis): def _apply_rotary_pos_emb(x, freqs_cis):
@ -89,17 +116,18 @@ def _apply_rotary_pos_emb(x, freqs_cis):
return out return out
class Gemma4Attention(nn.Module): class Gemma4Attention(nn.Module):
def __init__(self, config, head_dim, device=None, dtype=None, ops=None): def __init__(self, config, head_dim, num_kv_heads=None, k_eq_v=False, device=None, dtype=None, ops=None):
super().__init__() super().__init__()
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads self.num_kv_heads = num_kv_heads if num_kv_heads is not None else config.num_key_value_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_dim = head_dim self.head_dim = head_dim
self.inner_size = self.num_heads * head_dim self.inner_size = self.num_heads * head_dim
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype) self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) # k_eq_v: V reuses the K projection (no separate v_proj weight)
self.v_proj = None if k_eq_v else ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None self.q_norm = None
@ -133,7 +161,10 @@ class Gemma4Attention(nn.Module):
shareable_kv = None shareable_kv = None
else: else:
xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) if self.v_proj is not None:
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
else:
xv = xk # k_eq_v: V is the raw K projection (before k_norm/RoPE)
if self.k_norm is not None: if self.k_norm is not None:
xk = self.k_norm(xk) xk = self.k_norm(xk)
xv = rms_norm(xv) xv = rms_norm(xv)
@ -186,7 +217,10 @@ class TransformerBlockGemma4(nn.Module):
head_dim = config.head_dim if self.sliding_attention else config.global_head_dim head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops) # k_eq_v only on global layers, which then use num_global_key_value_heads
k_eq_v = config.attention_k_eq_v and not self.sliding_attention
num_kv_heads = config.num_global_key_value_heads if k_eq_v else config.num_key_value_heads
self.self_attn = Gemma4Attention(config, head_dim=head_dim, num_kv_heads=num_kv_heads, k_eq_v=k_eq_v, device=device, dtype=dtype, ops=ops)
num_kv_shared = config.num_kv_shared_layers num_kv_shared = config.num_kv_shared_layers
first_kv_shared = config.num_hidden_layers - num_kv_shared first_kv_shared = config.num_hidden_layers - num_kv_shared
@ -203,9 +237,9 @@ class TransformerBlockGemma4(nn.Module):
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
else: # layer_scalar exists on every gemma4 variant, independent of per-layer input
self.layer_scalar = None self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None): def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None):
sliding_window = None sliding_window = None
@ -245,7 +279,7 @@ class TransformerBlockGemma4(nn.Module):
x = residual + x x = residual + x
if self.layer_scalar is not None: if self.layer_scalar is not None:
x = x * self.layer_scalar x = x * comfy.model_management.cast_to_device(self.layer_scalar, x.device, x.dtype)
return x, present_key_value, shareable_kv return x, present_key_value, shareable_kv
@ -334,6 +368,19 @@ class Gemma4Transformer(nn.Module):
causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val) causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val)
mask = mask + causal_mask if mask is not None else causal_mask mask = mask + causal_mask if mask is not None else causal_mask
# Bidirectional attention within each image soft-token block (prefill only; text/audio stay causal).
if getattr(self.config, "vision_bidirectional", False) and past_len == 0 and embeds_info:
block_ids = torch.full((seq_len,), -1, dtype=torch.long, device=x.device)
group = 0
for info in embeds_info:
if info.get("type") == "image":
start = info["index"]
block_ids[start:start + info["size"]] = group
group += 1
if group > 0:
same_block = (block_ids[:, None] == block_ids[None, :]) & (block_ids[:, None] >= 0)
mask = mask.masked_fill(same_block, 0.0)
# Per-layer inputs # Per-layer inputs
per_layer_inputs = None per_layer_inputs = None
if self.hidden_size_per_layer_input: if self.hidden_size_per_layer_input:
@ -441,6 +488,28 @@ class Gemma4AudioMixin:
return None, None return None, None
class Gemma4UnifiedBase(Gemma4Base):
"""Encoder-free multimodal Gemma4 (gemma4_unified, e.g. 12B): raw image patches and audio frames projected directly into LM space."""
def _init_model(self, config, dtype, device, operations):
self.num_layers = config.num_hidden_layers
self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.vision_model = Gemma4UnifiedVisionEmbedder(config.vision_config, device=device, dtype=dtype, ops=operations)
self.multi_modal_projector = Gemma4RMSNormProjector(config.vision_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
self.audio_projector = Gemma4RMSNormProjector(config.audio_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
pixels = embed.pop("data").movedim(-1, 1).to(device, dtype=self.dtype) # [B, H, W, C] -> [B, C, H, W], [0,1]
patches, positions = self.vision_model.patchify(pixels)
vision_out = self.vision_model(patches, positions)
return self.multi_modal_projector(vision_out), None
if embed["type"] == "audio":
audio = embed.pop("data").to(device, dtype=self.dtype) # [1, T, audio_samples_per_token]
return self.audio_projector(audio), None
return None, None
# Vision Encoder # Vision Encoder
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
@ -713,6 +782,73 @@ class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops) super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
# Encoder-free vision (gemma4_unified): raw merged pixel patches projected directly into LM space.
def _patches_merge(patches, positions_xy, length):
patch_size = math.isqrt(patches.shape[-1] // 3)
k = math.isqrt(patches.shape[-2] // length)
batch = patches.shape[:-2]
max_x = positions_xy[..., 0].max(dim=-1, keepdim=True)[0] + 1
kidx = torch.div(positions_xy, k, rounding_mode="floor")
rem = torch.remainder(positions_xy, k)
order = rem[..., 0] + rem[..., 1] * k + k * k * kidx[..., 0] + k * max_x * kidx[..., 1]
perm = order.long().argsort(dim=-1)
merged = patches.gather(-2, perm.unsqueeze(-1).expand_as(patches))
merged = merged.reshape(*batch, length, k, k, patch_size, patch_size, 3)
merged = merged.permute(*range(len(batch)), -6, -5, -3, -4, -2, -1).reshape(*batch, length, (k * patch_size) ** 2 * 3)
pos = positions_xy.float().gather(-2, perm.unsqueeze(-1).expand_as(positions_xy).long())
pad = (positions_xy == -1).all(dim=-1, keepdim=True)
pos = torch.where(pad, positions_xy.float(), pos).reshape(*batch, length, k * k, 2)
pos = torch.div(pos, k, rounding_mode="floor").min(dim=-2)[0].to(torch.long)
return merged, pos
class Gemma4UnifiedVisionEmbedder(nn.Module):
"""Encoder-free patch embedder (LN -> Dense -> LN -> +2D posemb -> LN); projection to text space is the separate multi_modal_projector."""
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.patch_size = config["patch_size"]
self.pooling_kernel_size = config["pooling_kernel_size"]
patch_dim = config["model_patch_size"] ** 2 * 3
mm_embed_dim = config["mm_embed_dim"]
self.patch_ln1 = ops.LayerNorm(patch_dim, device=device, dtype=dtype)
self.patch_dense = ops.Linear(patch_dim, mm_embed_dim, device=device, dtype=dtype)
self.patch_ln2 = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
self.pos_embedding = nn.Parameter(torch.empty(config["mm_posemb_size"], 2, mm_embed_dim, device=device, dtype=dtype))
self.pos_norm = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
def patchify(self, pixels):
"""pixels: [B, C, H, W] in [0,1] -> merged patches [B, N, 6912], positions [B, N, 2]."""
ps, k = self.patch_size, self.pooling_kernel_size
out_patches, out_positions = [], []
for img in pixels:
ph, pw = img.shape[-2] // ps, img.shape[-1] // ps
teacher = img.reshape(img.shape[0], ph, ps, pw, ps).permute(1, 3, 2, 4, 0).reshape(ph * pw, -1)
grid = torch.meshgrid(torch.arange(pw, device=img.device), torch.arange(ph, device=img.device), indexing="xy")
tpos = torch.stack(grid, dim=-1).reshape(teacher.shape[0], 2)
n_model = teacher.shape[0] // (k * k)
mp, mpos = _patches_merge(teacher.unsqueeze(0), tpos.unsqueeze(0), n_model)
out_patches.append(mp.squeeze(0))
out_positions.append(mpos.squeeze(0))
return torch.stack(out_patches), torch.stack(out_positions)
def forward(self, pixel_values, image_position_ids):
x = self.patch_ln1(pixel_values)
x = self.patch_dense(x)
x = self.patch_ln2(x)
clamped = image_position_ids.clamp(min=0).long()
valid = (image_position_ids != -1).to(x.dtype).unsqueeze(-1)
axes = torch.arange(2, device=image_position_ids.device)
pos = comfy.model_management.cast_to_device(self.pos_embedding, x.device, x.dtype)
pos_embs = (pos[clamped, axes] * valid).sum(-2)
x = x + pos_embs
return self.pos_norm(x)
# Audio Encoder # Audio Encoder
class Gemma4AudioConvSubsampler(nn.Module): class Gemma4AudioConvSubsampler(nn.Module):
@ -998,25 +1134,35 @@ class Gemma4_Tokenizer():
return {"tokenizer_json": self.tokenizer_json_data} return {"tokenizer_json": self.tokenizer_json_data}
return {} return {}
def _extract_mel_spectrogram(self, waveform, sample_rate): def _audio_token_count(self, num_samples):
"""Extract 128-bin log mel spectrogram. # Default (E2B/E4B): mel frames after two stride-2 conv subsamples.
Uses numpy for FFT/matmul/log to produce bit-identical results with reference code. _fl = 320 # int(round(16000 * 20.0 / 1000.0))
""" _hl = 160 # int(round(16000 * 10.0 / 1000.0))
# Mix to mono first, then resample to 16kHz _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
_t = _nmel
for _ in range(2):
_t = (_t + 2 - 3) // 2 + 1
return min(_t, 750)
@staticmethod
def _resample_16k(waveform, sample_rate):
"""Mix to mono and resample to 16kHz. Kaiser params reproduce the reference (transformers
load_audio -> librosa/soxr_hq) to ~1e-12 MSE using only torchaudio."""
if waveform.dim() > 1 and waveform.shape[0] > 1: if waveform.dim() > 1 and waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True) waveform = waveform.mean(dim=0, keepdim=True)
if waveform.dim() == 1: if waveform.dim() == 1:
waveform = waveform.unsqueeze(0) waveform = waveform.unsqueeze(0)
audio = waveform.squeeze(0).float().numpy() audio = waveform.float()
if sample_rate != 16000: if sample_rate != 16000:
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match) audio = AF.resample(audio, sample_rate, 16000, resampling_method="sinc_interp_kaiser",
from scipy.signal import resample_poly, firwin lowpass_filter_width=121, rolloff=0.9568384289091556, beta=21.01531462440614)
from math import gcd return audio.squeeze(0).contiguous()
g = gcd(sample_rate, 16000)
up, down = 16000 // g, sample_rate // g def _extract_audio_features(self, waveform, sample_rate):
L = max(up, down) """Default (E2B/E4B): 128-bin log mel spectrogram for the conformer audio encoder.
h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5)) Uses numpy for FFT/matmul/log to produce bit-identical results with reference code.
audio = resample_poly(audio, up, down, window=h).astype(np.float32) """
audio = self._resample_16k(waveform, sample_rate).numpy()
n = len(audio) n = len(audio)
# Pad to multiple of 128, build sample-level mask # Pad to multiple of 128, build sample-level mask
@ -1064,8 +1210,8 @@ class Gemma4_Tokenizer():
if audio is not None: if audio is not None:
waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000 sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) feat, feat_mask = self._extract_audio_features(waveform, sample_rate)
audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) audio_features = [(feat.unsqueeze(0), feat_mask.unsqueeze(0))] # ([1, T, D], [1, T])
# Process image/video frames # Process image/video frames
is_video = video is not None is_video = video is not None
@ -1096,7 +1242,6 @@ class Gemma4_Tokenizer():
target_h = max(int(factor * h // side_mult) * side_mult, side_mult) target_h = max(int(factor * h // side_mult) * side_mult, side_mult)
target_w = max(int(factor * w // side_mult) * side_mult, side_mult) target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
import torchvision.transforms.functional as TVF
for i in range(num_frames): for i in range(num_frames):
# rescaling to match reference code # rescaling to match reference code
s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
@ -1115,7 +1260,7 @@ class Gemma4_Tokenizer():
llama_text = llama_template.format(text) llama_text = llama_template.format(text)
else: else:
# Build template from modalities present # Build template from modalities present
system = "<|turn>system\n<|think|><turn|>\n" if thinking else "" system = "<|turn>system\n<|think|>\n<turn|>\n" if thinking else ""
media = "" media = ""
if len(images) > 0: if len(images) > 0:
if is_video: if is_video:
@ -1135,15 +1280,11 @@ class Gemma4_Tokenizer():
if len(audio_features) > 0: if len(audio_features) > 0:
# Compute audio token count (always at 16kHz) # Compute audio token count (always at 16kHz)
num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1] num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1]
_fl = 320 # int(round(16000 * 20.0 / 1000.0)) n_audio_tokens = self._audio_token_count(num_samples)
_hl = 160 # int(round(16000 * 10.0 / 1000.0))
_nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
_t = _nmel
for _ in range(2):
_t = (_t + 2 - 3) // 2 + 1
n_audio_tokens = min(_t, 750)
media += "<|audio>" + "<|audio|>" * n_audio_tokens + "<audio|>" media += "<|audio>" + "<|audio|>" * n_audio_tokens + "<audio|>"
llama_text = f"{system}<|turn>user\n{media}{text}<turn|>\n<|turn>model\n" # Non-thinking mode primes an empty thought channel so the model answers directly.
model_open = "" if thinking else "<|channel>thought\n<channel|>"
llama_text = f"{system}<|turn>user\n{text}{media}<turn|>\n<|turn>model\n{model_open}"
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
@ -1178,7 +1319,6 @@ class Gemma4_Tokenizer():
class _Gemma4Tokenizer: class _Gemma4Tokenizer:
"""Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)""" """Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)"""
def __init__(self, tokenizer_json_bytes=None, **kwargs): def __init__(self, tokenizer_json_bytes=None, **kwargs):
from tokenizers import Tokenizer
if isinstance(tokenizer_json_bytes, torch.Tensor): if isinstance(tokenizer_json_bytes, torch.Tensor):
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
@ -1224,6 +1364,30 @@ class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class)
class Gemma4UnifiedSDTokenizer(Gemma4SDTokenizer):
"""Encoder-free (gemma4_unified) audio: raw 16kHz waveform frames instead of mel spectrogram."""
embedding_size = 3840
def _extract_audio_features(self, waveform, sample_rate):
audio = self._resample_16k(waveform, sample_rate)
spt = 640 # audio_samples_per_token (40ms at 16kHz)
pad = (-audio.shape[0]) % spt
if pad:
audio = torch.nn.functional.pad(audio, (0, pad))
num_tokens = audio.shape[0] // spt
feats = audio[:num_tokens * spt].reshape(num_tokens, spt)
feats = feats[:750] # audio_seq_length cap (matches reference truncation, ~30s)
mask = torch.ones(feats.shape[0], dtype=torch.bool)
return feats, mask
def _audio_token_count(self, num_samples):
return min((num_samples + 639) // 640, 750)
class Gemma4UnifiedTokenizer(Gemma4Tokenizer):
tokenizer_class = Gemma4UnifiedSDTokenizer
# Model wrappers # Model wrappers
class Gemma4Model(sd1_clip.SDClipModel): class Gemma4Model(sd1_clip.SDClipModel):
model_class = None model_class = None
@ -1256,7 +1420,7 @@ class Gemma4Model(sd1_clip.SDClipModel):
expanded_idx += 1 expanded_idx += 1
initial_token_ids = [ids] initial_token_ids = [ids]
input_ids = torch.tensor(initial_token_ids, device=self.execution_device) input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids, embeds_info=embeds_info)
def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None): def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None):
@ -1296,3 +1460,11 @@ def _make_variant(config_cls):
Gemma4_E4B = _make_variant(Gemma4Config) Gemma4_E4B = _make_variant(Gemma4Config)
Gemma4_E2B = _make_variant(Gemma4_E2B_Config) Gemma4_E2B = _make_variant(Gemma4_E2B_Config)
Gemma4_31B = _make_variant(Gemma4_31B_Config) Gemma4_31B = _make_variant(Gemma4_31B_Config)
# Gemma4 12B Unified: encoder-free multimodal, distinct base/tokenizer (not via _make_variant).
class Gemma4_12B(Gemma4UnifiedBase):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self._init_model(Gemma4_12B_Config(**config_dict), dtype, device, operations)
Gemma4_12B.tokenizer = Gemma4UnifiedTokenizer

View File

@ -860,7 +860,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None): def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, embeds_info=None):
device = embeds.device device = embeds.device
if stop_tokens is None: if stop_tokens is None:
@ -887,7 +887,7 @@ class BaseGenerate:
# Generation loop # Generation loop
current_input_ids = initial_input_ids current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"): for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids) x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, embeds_info=(embeds_info if step == 0 else None))
logits = self.logits(x)[:, -1] logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item() token_id = next_token[0].item()